### What problem does this PR solve?

- Add a comprehensive test suite for chunk operations, including:
  - Test files for create, list, retrieve, update, and delete chunks
  - Authorization tests
  - Batch-operation tests
- Update test configurations and common utilities
- Validate that the `important_kwd` and `question_kwd` fields are lists in `chunk_app.py`
- Reorganize imports and clean up duplicate code

### Type of change

- [x] Add test cases
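The list check called out above is a small type guard in the chunk endpoints. A minimal sketch of the pattern for orientation (simplified for illustration; the real handlers short-circuit through `get_data_error_result` rather than returning a message string):

```python
# Simplified sketch of the list-type guard this PR adds; not the exact
# handler code from chunk_app.py.
def validate_keyword_fields(req: dict):
    """Return an error message if a keyword field is not a list, else None."""
    for field in ("important_kwd", "question_kwd"):
        if field in req and not isinstance(req[field], list):
            return f"`{field}` should be a list"
    return None


assert validate_keyword_fields({"important_kwd": "abc"}) == "`important_kwd` should be a list"
assert validate_keyword_fields({"question_kwd": ["a", "b"]}) is None
```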
#
import datetime
import json
import re

import xxhash
from flask import request
from flask_login import current_user, login_required

from api import settings
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
@manager.route('/list', methods=['POST'])  # noqa: F821

# ...
    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if "important_kwd" in req:
        if not isinstance(req["important_kwd"], list):
            return get_data_error_result(message="`important_kwd` should be a list")
        d["important_kwd"] = req["important_kwd"]
        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
    if "question_kwd" in req:
        if not isinstance(req["question_kwd"], list):
            return get_data_error_result(message="`question_kwd` should be a list")
        d["question_kwd"] = req["question_kwd"]
        d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
    if "tag_kwd" in req:
        d["tag_kwd"] = req["tag_kwd"]
    # ...
    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
    # ...
    if "tag_feas" in req:
        d["tag_feas"] = req["tag_feas"]
    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
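Worth noting for reviewers: the new guard only validates the container type. A non-string item such as `[1]` still fails later when the keywords are joined for tokenization, which is exactly the `TypeError` the tests below expect (a standalone illustration of the underlying behavior, not the handler itself):

```python
# " ".join() raises on non-string items, so important_kwd=[1] surfaces as
# code 100 with this exact message in the test expectations below.
try:
    " ".join([1])
except TypeError as exc:
    print(repr(exc))  # TypeError('sequence item 0: expected str instance, int found')
```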
# See the License for the specific language governing permissions and
# limitations under the License.
#
from time import sleep

import pytest
KB_APP_URL = f"/{VERSION}/kb"
DOCUMENT_APP_URL = f"/{VERSION}/document"
CHUNK_API_URL = f"/{VERSION}/chunk"
# CHAT_ASSISTANT_API_URL = "/api/v1/chats"
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"

# ...
    for document in res["data"]:
        document_ids.append(document["id"])
    return document_ids


# CHUNK APP
def add_chunk(auth, payload=None, *, headers=HEADERS, data=None):
    res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/create", headers=headers, auth=auth, json=payload, data=data)
    return res.json()


def list_chunks(auth, payload=None, *, headers=HEADERS):
    res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/list", headers=headers, auth=auth, json=payload)
    return res.json()


def get_chunk(auth, params=None, *, headers=HEADERS):
    res = requests.get(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/get", headers=headers, auth=auth, params=params)
    return res.json()


def update_chunk(auth, payload=None, *, headers=HEADERS):
    res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/set", headers=headers, auth=auth, json=payload)
    return res.json()


def delete_chunks(auth, payload=None, *, headers=HEADERS):
    res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/rm", headers=headers, auth=auth, json=payload)
    return res.json()


def retrieval_chunks(auth, payload=None, *, headers=HEADERS):
    res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/retrieval_test", headers=headers, auth=auth, json=payload)
    return res.json()


def batch_add_chunks(auth, doc_id, num):
    chunk_ids = []
    for i in range(num):
        res = add_chunk(auth, {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"})
        chunk_ids.append(res["data"]["chunk_id"])
    return chunk_ids
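As a quick orientation for reviewers, a hedged sketch of how these helpers compose into a round trip (not an actual test in this PR; `auth` and `doc_id` come from the surrounding fixtures, and the payload keys mirror the web API):

```python
# Illustrative usage of the helpers defined above.
def example_chunk_roundtrip(auth, doc_id):
    created = add_chunk(auth, {"doc_id": doc_id, "content_with_weight": "hello"})
    chunk_id = created["data"]["chunk_id"]

    listed = list_chunks(auth, {"doc_id": doc_id})
    assert any(c["chunk_id"] == chunk_id for c in listed["data"]["chunks"])

    update_chunk(auth, {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "hello, again"})
    delete_chunks(auth, {"doc_id": doc_id, "chunk_ids": [chunk_id]})
```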
# See the License for the specific language governing permissions and
# limitations under the License.
#
from time import sleep

import pytest
from common import (
    batch_add_chunks,
    batch_create_datasets,
    bulk_upload_documents,
    delete_chunks,
    list_chunks,
    list_documents,
    list_kbs,
    parse_documents,
    rm_kb,
)

# from configs import HOST_ADDRESS, VERSION
from libs.auth import RAGFlowWebApiAuth
from pytest import FixtureRequest

# from ragflow_sdk import RAGFlow
from utils import wait_for
from utils.file_utils import (
    create_docx_file,
    create_eml_file,
)


@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
    res = list_documents(_auth, {"kb_id": _kb_id})
    for doc in res["data"]["docs"]:
        if doc["run"] != "3":
            return False
    return True


@pytest.fixture
def generate_test_files(request: FixtureRequest, tmp_path):
    file_creators = {
# ...
    return RAGFlowWebApiAuth(auth)


# @pytest.fixture(scope="session")
# def client(token: str) -> RAGFlow:
#     return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)


@pytest.fixture(scope="function")
def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth):
    def cleanup():
# ...
    request.addfinalizer(cleanup)

# ...
    return batch_create_datasets(WebApiAuth, 1)[0]


@pytest.fixture(scope="class")
def add_document(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
    # def cleanup():
    #     res = list_documents(WebApiAuth, {"kb_id": dataset_id})
    #     for doc in res["data"]["docs"]:
    #         delete_document(WebApiAuth, {"doc_id": doc["id"]})
    # request.addfinalizer(cleanup)
    dataset_id = add_dataset
    return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 1, ragflow_tmp_dir)[0]


@pytest.fixture(scope="class")
def add_chunks(request, WebApiAuth, add_document):
    def cleanup():
        res = list_chunks(WebApiAuth, {"doc_id": document_id})
        if res["code"] == 0:
            chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]]
            delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids})

    request.addfinalizer(cleanup)

    kb_id, document_id = add_document
    parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
    condition(WebApiAuth, kb_id)
    chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
    # issues/6487
    sleep(1)
    return kb_id, document_id, chunk_ids
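The class-scoped fixture returns a `(kb_id, document_id, chunk_ids)` tuple; a hedged sketch of how a test unpacks it (illustrative only, using the `list_chunks` helper imported from `common` as above):

```python
# Not part of this PR; shows the shape the real tests below follow.
class TestExampleUsage:
    def test_fixture_unpacking(self, WebApiAuth, add_chunks):
        kb_id, document_id, chunk_ids = add_chunks
        assert len(chunk_ids) == 4  # the fixture adds four chunks

        res = list_chunks(WebApiAuth, {"doc_id": document_id})
        assert res["code"] == 0, res
```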
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from time import sleep

import pytest
from common import batch_add_chunks, delete_chunks, list_chunks, list_documents, parse_documents
from utils import wait_for


@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
    res = list_documents(_auth, {"kb_id": _kb_id})
    for doc in res["data"]["docs"]:
        if doc["run"] != "3":
            return False
    return True


@pytest.fixture(scope="function")
def add_chunks_func(request, WebApiAuth, add_document):
    def cleanup():
        res = list_chunks(WebApiAuth, {"doc_id": document_id})
        chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]]
        delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids})

    request.addfinalizer(cleanup)

    kb_id, document_id = add_document
    parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
    condition(WebApiAuth, kb_id)
    chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
    # issues/6487
    sleep(1)
    return kb_id, document_id, chunk_ids
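`wait_for` itself is not part of this diff; it lives in the suite's `utils` package. A plausible minimal sketch of such a polling decorator, offered purely as an assumption about its shape (timeout in seconds, poll interval, timeout message):

```python
# Hypothetical reconstruction of utils.wait_for, NOT the actual implementation.
import functools
import time


def wait_for(timeout, interval, error_message):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if func(*args, **kwargs):  # the condition returns True when satisfied
                    return True
                time.sleep(interval)
            raise TimeoutError(error_message)
        return wrapper
    return decorator
```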
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import add_chunk, delete_document, get_chunk, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


def validate_chunk_details(auth, kb_id, doc_id, payload, res):
    chunk_id = res["data"]["chunk_id"]
    res = get_chunk(auth, {"chunk_id": chunk_id})
    assert res["code"] == 0, res
    chunk = res["data"]
    assert chunk["doc_id"] == doc_id
    assert chunk["kb_id"] == kb_id
    assert chunk["content_with_weight"] == payload["content_with_weight"]
    if "important_kwd" in payload:
        assert chunk["important_kwd"] == payload["important_kwd"]
    if "question_kwd" in payload:
        expected = [str(q).strip() for q in payload.get("question_kwd", [])]
        assert chunk["question_kwd"] == expected


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = add_chunk(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestAddChunk:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": None}, 100, """TypeError("unsupported operand type(s) for +: 'NoneType' and 'str'")"""),
            ({"content_with_weight": ""}, 0, ""),
            pytest.param(
                {"content_with_weight": 1},
                100,
                """TypeError("unsupported operand type(s) for +: 'int' and 'str'")""",
                marks=pytest.mark.skip,
            ),
            ({"content_with_weight": "a"}, 0, ""),
            ({"content_with_weight": " "}, 0, ""),
            ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
        ],
    )
    def test_content(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0
        res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            assert res["code"] == 0, res
            assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": "chunk test", "important_kwd": ["a", "b", "c"]}, 0, ""),
            ({"content_with_weight": "chunk test", "important_kwd": [""]}, 0, ""),
            (
                {"content_with_weight": "chunk test", "important_kwd": [1]},
                100,
                "TypeError('sequence item 0: expected str instance, int found')",
            ),
            ({"content_with_weight": "chunk test", "important_kwd": ["a", "a"]}, 0, ""),
            ({"content_with_weight": "chunk test", "important_kwd": "abc"}, 102, "`important_kwd` is required to be a list"),
            ({"content_with_weight": "chunk test", "important_kwd": 123}, 102, "`important_kwd` is required to be a list"),
        ],
    )
    def test_important_keywords(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0
        res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            assert res["code"] == 0, res
            assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": "chunk test", "question_kwd": ["a", "b", "c"]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": [""]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"content_with_weight": "chunk test", "question_kwd": ["a", "a"]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": "abc"}, 102, "`question_kwd` is required to be a list"),
            ({"content_with_weight": "chunk test", "question_kwd": 123}, 102, "`question_kwd` is required to be a list"),
        ],
    )
    def test_questions(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0
        res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            assert res["code"] == 0, res
            assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id, expected_code, expected_message",
        [
            ("", 102, "Document not found!"),
            ("invalid_document_id", 102, "Document not found!"),
        ],
    )
    def test_invalid_document_id(self, WebApiAuth, add_document, doc_id, expected_code, expected_message):
        _, _ = add_document
        res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res

    @pytest.mark.p3
    def test_repeated_add_chunk(self, WebApiAuth, add_document):
        payload = {"content_with_weight": "chunk test"}
        kb_id, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        chunks_count = res["data"]["doc"]["chunk_num"]

        res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
        assert res["code"] == 0, res
        validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res

        res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
        assert res["code"] == 0, res
        validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == chunks_count + 2, res

    @pytest.mark.p2
    def test_add_chunk_to_deleted_document(self, WebApiAuth, add_document):
        _, doc_id = add_document
        delete_document(WebApiAuth, {"doc_id": doc_id})
        res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
        assert res["code"] == 102, res
        assert res["message"] == "Document not found!", res

    @pytest.mark.skip(reason="issues/6411")
    @pytest.mark.p3
    def test_concurrent_add_chunk(self, WebApiAuth, add_document):
        count = 50
        _, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    add_chunk,
                    WebApiAuth,
                    {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"},
                )
                for i in range(count)
            ]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == chunks_count + count
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import sleep

import pytest
from common import batch_add_chunks, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = list_chunks(invalid_auth, {"doc_id": "document_id"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestChunksList:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
            pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
            ({"page": 2, "size": 2}, 0, 2, ""),
            ({"page": 3, "size": 2}, 0, 1, ""),
            ({"page": "3", "size": 2}, 0, 1, ""),
            pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
            pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip),
        ],
    )
    def test_page(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
        _, doc_id, _ = add_chunks
        payload = {"doc_id": doc_id}
        if params:
            payload.update(params)
        res = list_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            ({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")"""),
            pytest.param({"size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")),
            pytest.param({"size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")),
            ({"size": 1}, 0, 1, ""),
            ({"size": 6}, 0, 5, ""),
            ({"size": "1"}, 0, 1, ""),
            pytest.param({"size": -1}, 0, 5, "", marks=pytest.mark.skip),
            pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip),
        ],
    )
    def test_page_size(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
        _, doc_id, _ = add_chunks
        payload = {"doc_id": doc_id}
        if params:
            payload.update(params)
        res = list_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, expected_page_size",
        [
            ({"keywords": None}, 5),
            ({"keywords": ""}, 5),
            ({"keywords": "1"}, 1),
            pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
            ({"keywords": "content"}, 1),
            ({"keywords": "unknown"}, 0),
        ],
    )
    def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size):
        _, doc_id, _ = add_chunks
        payload = {"doc_id": doc_id}
        if params:
            payload.update(params)
        res = list_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == expected_page_size, res

    @pytest.mark.p3
    def test_invalid_params(self, WebApiAuth, add_chunks):
        _, doc_id, _ = add_chunks
        payload = {"doc_id": doc_id, "a": "b"}
        res = list_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 5, res

    @pytest.mark.p3
    def test_concurrent_list(self, WebApiAuth, add_chunks):
        _, doc_id, _ = add_chunks
        count = 100
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(list_chunks, WebApiAuth, {"doc_id": doc_id}) for _ in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures)

    @pytest.mark.p1
    def test_default(self, WebApiAuth, add_document):
        _, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        chunks_count = res["data"]["doc"]["chunk_num"]
        batch_add_chunks(WebApiAuth, doc_id, 31)
        # issues/6487
        sleep(3)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0
        assert len(res["data"]["chunks"]) == 30
        assert res["data"]["doc"]["chunk_num"] == chunks_count + 31
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import retrieval_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestChunksRetrieval:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"question": "chunk", "kb_id": None}, 0, 4, ""),
            ({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
            ({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
            ({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, document_id, _ = add_chunks
        if "kb_id" in payload:
            payload["kb_id"] = [dataset_id]
        if "doc_ids" in payload:
            payload["doc_ids"] = [document_id]
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"page": None, "size": 2},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": 0, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
            ({"page": 3, "size": 2}, 0, 0, ""),
            ({"page": "3", "size": 2}, 0, 0, ""),
            pytest.param(
                {"page": -1, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": "a", "size": 2},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"size": None},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            # ({"size": 0}, 0, 0, ""),
            ({"size": 1}, 0, 1, ""),
            ({"size": 5}, 0, 4, ""),
            ({"size": "1"}, 0, 1, ""),
            # ({"size": -1}, 0, 0, ""),
            pytest.param(
                {"size": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"vector_similarity_weight": 0}, 0, 4, ""),
            ({"vector_similarity_weight": 0.5}, 0, 4, ""),
            ({"vector_similarity_weight": 10}, 0, 4, ""),
            pytest.param(
                {"vector_similarity_weight": "a"},
                100,
                0,
                """ValueError("could not convert string to float: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"top_k": 10}, 0, 4, ""),
            pytest.param(
                {"top_k": 1},
                0,
                4,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": 1},
                0,
                1,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "must be greater than 0",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "3014",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
            pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
        ],
    )
    def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) > 0, res
        else:
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"keyword": True}, 0, 5, ""),
            ({"keyword": "True"}, 0, 5, ""),
            ({"keyword": False}, 0, 5, ""),
            ({"keyword": "False"}, 0, 5, ""),
            ({"keyword": None}, 0, 5, ""),
        ],
    )
    def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk test", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_highlight, expected_message",
        [
            ({"highlight": True}, 0, True, ""),
            ({"highlight": "True"}, 0, True, ""),
            pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
            pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
            pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
        ],
    )
    def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_highlight:
            for chunk in res["data"]["chunks"]:
                assert "highlight" in chunk, res
        else:
            for chunk in res["data"]["chunks"]:
                assert "highlight" not in chunk, res
        if expected_code != 0:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    def test_invalid_params(self, WebApiAuth, add_chunks):
        dataset_id, _, _ = add_chunks
        payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 4, res

    @pytest.mark.p3
    def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
        dataset_id, _, _ = add_chunks
        count = 100
        payload = {"question": "chunk", "kb_id": [dataset_id]}
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for _ in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)
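A note on `vector_similarity_weight`: the tests pass 0, 0.5 and even 10 and still get all four chunks back, which is consistent with the weight only re-ranking a blended score rather than filtering results. A hedged model of the conventional hybrid blend (an assumption for illustration; RAGFlow's actual scoring lives in its retrieval code and is not part of this diff):

```python
# Conventional hybrid blend: score = (1 - w) * term_sim + w * vector_sim.
# An assumption, not code from this PR.
def hybrid_score(term_sim: float, vector_sim: float, w: float) -> float:
    return (1 - w) * term_sim + w * vector_sim


print(hybrid_score(0.8, 0.4, 0.0))  # 0.8: pure term-match score
print(hybrid_score(0.8, 0.4, 0.5))  # 0.6: even blend
print(hybrid_score(0.8, 0.4, 10))   # -3.2: an out-of-range weight distorts ranking only
```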
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import sleep

import pytest
from common import batch_add_chunks, delete_chunks, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = delete_chunks(invalid_auth, {"doc_id": "document_id", "chunk_ids": ["1"]})
        assert res["code"] == expected_code
        assert res["message"] == expected_message


class TestChunksDeletion:
    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id, expected_code, expected_message",
        [
            ("", 102, "Document not found!"),
            ("invalid_document_id", 102, "Document not found!"),
        ],
    )
    def test_invalid_document_id(self, WebApiAuth, add_chunks_func, doc_id, expected_code, expected_message):
        _, _, chunk_ids = add_chunks_func
        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res

    @pytest.mark.parametrize(
        "payload",
        [
            pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
            pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
        ],
    )
    def test_delete_partial_invalid_id(self, WebApiAuth, add_chunks_func, payload):
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.p3
    def test_repeated_deletion(self, WebApiAuth, add_chunks_func):
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Index updating failure", res

    @pytest.mark.p3
    def test_duplicate_deletion(self, WebApiAuth, add_chunks_func):
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids * 2, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.p3
    def test_concurrent_deletion(self, WebApiAuth, add_document):
        count = 100
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, count)
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    delete_chunks,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_ids": chunk_ids[i : i + 1]},
                )
                for i in range(count)
            ]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)

    @pytest.mark.p3
    def test_delete_1k(self, WebApiAuth, add_document):
        chunks_num = 1_000
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, chunks_num)
        sleep(1)
        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == 0

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message, remaining",
        [
            pytest.param(None, 100, """TypeError("argument of type 'NoneType' is not iterable")""", 5, marks=pytest.mark.skip),
            pytest.param({"chunk_ids": ["invalid_id"]}, 102, "Index updating failure", 4, marks=pytest.mark.p3),
            pytest.param("not json", 100, """UnboundLocalError("local variable 'duplicate_messages' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")),
            pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1),
            pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining):
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if res["code"] != 0:
            assert res["message"] == expected_message, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == remaining, res
        assert res["data"]["total"] == remaining, res
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint
from time import sleep

import pytest
from common import delete_document, list_chunks, update_chunk
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = update_chunk(invalid_auth, {"doc_id": "doc_id", "chunk_id": "chunk_id", "content_with_weight": "test"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestUpdateChunk:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": None}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": ""}, 0, ""),
            ({"content_with_weight": 1}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": "update chunk"}, 0, ""),
            ({"content_with_weight": " "}, 0, ""),
            ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
        ],
    )
    def test_content(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id}
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["content_with_weight"] == payload["content_with_weight"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"important_kwd": ["a", "b", "c"]}, 0, ""),
            ({"important_kwd": [""]}, 0, ""),
            ({"important_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"important_kwd": ["a", "a"]}, 0, ""),
            ({"important_kwd": "abc"}, 102, "`important_kwd` should be a list"),
            ({"important_kwd": 123}, 102, "`important_kwd` should be a list"),
        ],
    )
    def test_important_keywords(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}  # content_with_weight is required
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["important_kwd"] == payload["important_kwd"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"question_kwd": ["a", "b", "c"]}, 0, ""),
            ({"question_kwd": [""]}, 0, ""),
            ({"question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"question_kwd": ["a", "a"]}, 0, ""),
            ({"question_kwd": "abc"}, 102, "`question_kwd` should be a list"),
            ({"question_kwd": 123}, 102, "`question_kwd` should be a list"),
        ],
    )
    def test_questions(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}  # content_with_weight is required
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["question_kwd"] == payload["question_kwd"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"available_int": 1}, 0, ""),
            ({"available_int": 0}, 0, ""),
        ],
    )
    def test_available(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["available_int"] == payload["available_int"]

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id_param, expected_code, expected_message",
        [
            ("", 102, "Tenant not found!"),
            ("invalid_doc_id", 102, "Tenant not found!"),
        ],
    )
    def test_invalid_document_id_for_update(self, WebApiAuth, add_chunks, doc_id_param, expected_code, expected_message):
        _, _, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        payload = {"doc_id": doc_id_param, "chunk_id": chunk_id, "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == expected_code
        assert expected_message in res["message"]

    @pytest.mark.p3
    def test_repeated_update_chunk(self, WebApiAuth, add_chunks):
        _, doc_id, chunk_ids = add_chunks
        payload1 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 1"}
        res = update_chunk(WebApiAuth, payload1)
        assert res["code"] == 0

        payload2 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 2"}
        res = update_chunk(WebApiAuth, payload2)
        assert res["code"] == 0

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"unknown_key": "unknown_value"}, 0, ""),
            ({}, 0, ""),
            pytest.param(None, 100, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
        ],
    )
    def test_invalid_params(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload is not None:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
    def test_concurrent_update_chunk(self, WebApiAuth, add_chunks):
        count = 50
        _, doc_id, chunk_ids = add_chunks
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    update_chunk,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_id": chunk_ids[randint(0, 3)], "content_with_weight": f"update chunk test {i}"},
                )
                for i in range(count)
            ]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)

    @pytest.mark.p3
    def test_update_chunk_to_deleted_document(self, WebApiAuth, add_chunks):
        _, doc_id, chunk_ids = add_chunks
        delete_document(WebApiAuth, {"doc_id": doc_id})
        payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Tenant not found!", res