### What problem does this PR solve?

Refactor test fixtures and test cases.

### Type of change

- [x] Refactoring test cases
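The diff below is easier to review with the overall shape of the change in mind. Roughly, per-test setup is replaced by shared fixtures (an illustrative sketch, not taken from the diff; `test_something` is a made-up name, while `batch_create_datasets`, `bulk_upload_documents` and `add_document` are the helpers/fixtures touched in this PR):

```python
# Before: each test created its own dataset and uploaded its own document inline.
def test_something(get_http_api_auth, tmp_path):
    ids = batch_create_datasets(get_http_api_auth, 1)
    document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
    ...

# After: a shared, class-scoped fixture hands back ready-made ids and cleans up on teardown.
def test_something(get_http_api_auth, add_document):
    dataset_id, document_id = add_document
    ...
```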
| @@ -15,26 +15,27 @@ | |||
| # | |||
| import os | |||
| import pytest | |||
| import requests | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380') | |||
| HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") | |||
| # def generate_random_email(): | |||
| # return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com' | |||
| def generate_email(): | |||
| return 'user_123@1.com' | |||
| return "user_123@1.com" | |||
| EMAIL = generate_email() | |||
| # password is "123" | |||
| PASSWORD = '''ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD | |||
| PASSWORD = """ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD | |||
| fN33jCHRoDUW81IH9zjij/vaw8IbVyb6vuwg6MX6inOEBRRzVbRYxXOu1wkWY6SsI8X70oF9aeLFp/PzQpjoe/YbSqpTq8qqrmHzn9vO+yvyYyvmDsphXe | |||
| X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w==''' | |||
| X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w==""" | |||
| def register(): | |||
| @@ -92,3 +93,64 @@ def get_email(): | |||
| @pytest.fixture(scope="session") | |||
| def get_http_api_auth(get_api_key_fixture): | |||
| return RAGFlowHttpApiAuth(get_api_key_fixture) | |||
| def get_my_llms(auth, name): | |||
| url = HOST_ADDRESS + "/v1/llm/my_llms" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| if name in res.get("data"): | |||
| return True | |||
| return False | |||
| def add_models(auth): | |||
| url = HOST_ADDRESS + "/v1/llm/set_api_key" | |||
| authorization = {"Authorization": auth} | |||
| models_info = { | |||
| "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": "d06253dacd404180aa8afb096fcb6c30.KatwBIUpvCSml9sU"}, | |||
| } | |||
| for name, model_info in models_info.items(): | |||
| if not get_my_llms(auth, name): | |||
| response = requests.post(url=url, headers=authorization, json=model_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| def get_tenant_info(auth): | |||
| url = HOST_ADDRESS + "/v1/user/tenant_info" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| return res["data"].get("tenant_id") | |||
| @pytest.fixture(scope="session", autouse=True) | |||
| def set_tenant_info(get_auth): | |||
| auth = get_auth | |||
| try: | |||
| add_models(auth) | |||
| tenant_id = get_tenant_info(auth) | |||
| except Exception as e: | |||
| raise Exception(e) | |||
| url = HOST_ADDRESS + "/v1/user/set_tenant_info" | |||
| authorization = {"Authorization": get_auth} | |||
| tenant_info = { | |||
| "tenant_id": tenant_id, | |||
| "llm_id": "glm-4-flash@ZHIPU-AI", | |||
| "embd_id": "embedding-3@ZHIPU-AI", | |||
| "img2txt_id": "glm-4v@ZHIPU-AI", | |||
| "asr_id": "", | |||
| "tts_id": None, | |||
| } | |||
| response = requests.post(url=url, headers=authorization, json=tenant_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| @@ -27,6 +27,7 @@ DATASETS_API_URL = "/api/v1/datasets" | |||
| FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents" | |||
| FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" | |||
| CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" | |||
| CHAT_ASSISTANT_API_URL = "/api/v1/chats" | |||
| INVALID_API_TOKEN = "invalid_key_123" | |||
| DATASET_NAME_LIMIT = 128 | |||
| @@ -39,7 +40,7 @@ def create_dataset(auth, payload=None): | |||
| return res.json() | |||
| def list_dataset(auth, params=None): | |||
| def list_datasets(auth, params=None): | |||
| res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params) | |||
| return res.json() | |||
| @@ -49,7 +50,7 @@ def update_dataset(auth, dataset_id, payload=None): | |||
| return res.json() | |||
| def delete_dataset(auth, payload=None): | |||
| def delete_datasets(auth, payload=None): | |||
| res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| @@ -105,7 +106,7 @@ def download_document(auth, dataset_id, document_id, save_path): | |||
| return res | |||
| def list_documnet(auth, dataset_id, params=None): | |||
| def list_documnets(auth, dataset_id, params=None): | |||
| url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) | |||
| res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) | |||
| return res.json() | |||
| @@ -117,19 +118,19 @@ def update_documnet(auth, dataset_id, document_id, payload=None): | |||
| return res.json() | |||
| def delete_documnet(auth, dataset_id, payload=None): | |||
| def delete_documnets(auth, dataset_id, payload=None): | |||
| url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) | |||
| res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| def parse_documnet(auth, dataset_id, payload=None): | |||
| def parse_documnets(auth, dataset_id, payload=None): | |||
| url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) | |||
| res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| def stop_parse_documnet(auth, dataset_id, payload=None): | |||
| def stop_parse_documnets(auth, dataset_id, payload=None): | |||
| url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) | |||
| res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| @@ -184,3 +185,36 @@ def batch_add_chunks(auth, dataset_id, document_id, num): | |||
| res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"}) | |||
| chunk_ids.append(res["data"]["chunk"]["id"]) | |||
| return chunk_ids | |||
| # CHAT ASSISTANT MANAGEMENT | |||
| def create_chat_assistant(auth, payload=None): | |||
| url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" | |||
| res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| def list_chat_assistants(auth, params=None): | |||
| url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" | |||
| res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) | |||
| return res.json() | |||
| def update_chat_assistant(auth, chat_assistant_id, payload=None): | |||
| url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" | |||
| res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| def delete_chat_assistants(auth, payload=None): | |||
| url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" | |||
| res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) | |||
| return res.json() | |||
| def batch_create_chat_assistants(auth, num): | |||
| chat_assistant_ids = [] | |||
| for i in range(num): | |||
| res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}"}) | |||
| chat_assistant_ids.append(res["data"]["id"]) | |||
| return chat_assistant_ids | |||
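For reference, the new chat-assistant helpers compose the same way as the dataset/document ones. A hedged usage sketch (the API key is a placeholder; imports assume the test-suite layout shown in this PR):

```python
from common import (
    batch_create_chat_assistants,
    delete_chat_assistants,
    list_chat_assistants,
    update_chat_assistant,
)
from libs.auth import RAGFlowHttpApiAuth

auth = RAGFlowHttpApiAuth("<api-key>")  # placeholder key

chat_ids = batch_create_chat_assistants(auth, 3)            # create 3 assistants
res = list_chat_assistants(auth, params={"page_size": 10})  # list them back
assert res["code"] == 0

update_chat_assistant(auth, chat_ids[0], {"name": "renamed_assistant"})
delete_chat_assistants(auth, {"ids": chat_ids})             # bulk delete by id
```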
| @@ -14,9 +14,8 @@ | |||
| # limitations under the License. | |||
| # | |||
| import pytest | |||
| from common import delete_dataset | |||
| from common import batch_create_datasets, bulk_upload_documents, delete_datasets | |||
| from libs.utils.file_utils import ( | |||
| create_docx_file, | |||
| create_eml_file, | |||
| @@ -34,7 +33,7 @@ from libs.utils.file_utils import ( | |||
| @pytest.fixture(scope="function") | |||
| def clear_datasets(get_http_api_auth): | |||
| yield | |||
| delete_dataset(get_http_api_auth) | |||
| delete_datasets(get_http_api_auth) | |||
| @pytest.fixture | |||
| @@ -58,3 +57,38 @@ def generate_test_files(request, tmp_path): | |||
| creator_func(file_path) | |||
| files[file_type] = file_path | |||
| return files | |||
| @pytest.fixture(scope="class") | |||
| def ragflow_tmp_dir(request, tmp_path_factory): | |||
| class_name = request.cls.__name__ | |||
| return tmp_path_factory.mktemp(class_name) | |||
| @pytest.fixture(scope="class") | |||
| def add_dataset(request, get_http_api_auth): | |||
| def cleanup(): | |||
| delete_datasets(get_http_api_auth) | |||
| request.addfinalizer(cleanup) | |||
| dataset_ids = batch_create_datasets(get_http_api_auth, 1) | |||
| return dataset_ids[0] | |||
| @pytest.fixture(scope="function") | |||
| def add_dataset_func(request, get_http_api_auth): | |||
| def cleanup(): | |||
| delete_datasets(get_http_api_auth) | |||
| request.addfinalizer(cleanup) | |||
| dataset_ids = batch_create_datasets(get_http_api_auth, 1) | |||
| return dataset_ids[0] | |||
| @pytest.fixture(scope="class") | |||
| def add_document(get_http_api_auth, add_dataset, ragflow_tmp_dir): | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir) | |||
| return dataset_id, document_ids[0] | |||
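How tests consume these fixtures, as a sketch (the class name is hypothetical; the point is the scoping): `add_dataset`/`add_document` are class-scoped, so every test in a class shares one dataset/document, while the `*_func` variants are function-scoped and give each test a fresh one.

```python
import pytest


class TestExample:  # hypothetical consumer
    def test_shares_document(self, get_http_api_auth, add_document):
        dataset_id, document_id = add_document  # same pair for all tests in this class
        assert dataset_id and document_id

    def test_fresh_dataset(self, get_http_api_auth, add_dataset_func):
        dataset_id = add_dataset_func  # new dataset just for this test
        assert dataset_id
```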
| @@ -16,13 +16,13 @@ | |||
| import pytest | |||
| from common import add_chunk, batch_create_datasets, bulk_upload_documents, delete_chunks, delete_dataset, list_documnet, parse_documnet | |||
| from common import add_chunk, delete_chunks, list_documnets, parse_documnets | |||
| from libs.utils import wait_for | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id): | |||
| res = list_documnet(_auth, _dataset_id) | |||
| res = list_documnets(_auth, _dataset_id) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
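`condition` is wrapped in `wait_for(10, 1, "Document parsing timeout")`, i.e. poll until every document reports `run == "DONE"` or give up after 10 seconds. The real decorator lives in `libs/utils` and may differ, but a sketch of that contract could look like:

```python
import time
from functools import wraps


def wait_for(timeout, interval, error_message):
    """Poll the wrapped predicate until it returns truthy or `timeout` seconds pass."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if func(*args, **kwargs):
                    return True
                time.sleep(interval)
            raise TimeoutError(error_message)

        return wrapper

    return decorator
```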
| @@ -30,29 +30,11 @@ def condition(_auth, _dataset_id): | |||
| @pytest.fixture(scope="class") | |||
| def chunk_management_tmp_dir(tmp_path_factory): | |||
| return tmp_path_factory.mktemp("chunk_management") | |||
| @pytest.fixture(scope="class") | |||
| def get_dataset_id_and_document_id(get_http_api_auth, chunk_management_tmp_dir, request): | |||
| def cleanup(): | |||
| delete_dataset(get_http_api_auth) | |||
| request.addfinalizer(cleanup) | |||
| dataset_ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = dataset_ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, chunk_management_tmp_dir) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| def add_chunks(get_http_api_auth, add_document): | |||
| dataset_id, document_id = add_document | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]}) | |||
| condition(get_http_api_auth, dataset_id) | |||
| return dataset_id, document_ids[0] | |||
| @pytest.fixture(scope="class") | |||
| def add_chunks(get_http_api_auth, get_dataset_id_and_document_id): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| chunk_ids = [] | |||
| for i in range(4): | |||
| res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"}) | |||
| @@ -66,8 +48,10 @@ def add_chunks(get_http_api_auth, get_dataset_id_and_document_id): | |||
| @pytest.fixture(scope="function") | |||
| def add_chunks_func(get_http_api_auth, get_dataset_id_and_document_id, request): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| def add_chunks_func(request, get_http_api_auth, add_document): | |||
| dataset_id, document_id = add_document | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]}) | |||
| condition(get_http_api_auth, dataset_id) | |||
| chunk_ids = [] | |||
| for i in range(4): | |||
| @@ -16,7 +16,7 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, add_chunk, delete_documnet, list_chunks | |||
| from common import INVALID_API_TOKEN, add_chunk, delete_documnets, list_chunks | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -44,7 +44,7 @@ class TestAuthorization: | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = add_chunk(auth, "dataset_id", "document_id", {}) | |||
| res = add_chunk(auth, "dataset_id", "document_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -66,8 +66,8 @@ class TestAddChunk: | |||
| ({"content": "\n!?。;!?\"'"}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_content(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| def test_content(self, get_http_api_auth, add_document, payload, expected_code, expected_message): | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| @@ -98,8 +98,8 @@ class TestAddChunk: | |||
| ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"), | |||
| ], | |||
| ) | |||
| def test_important_keywords(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| def test_important_keywords(self, get_http_api_auth, add_document, payload, expected_code, expected_message): | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| @@ -126,8 +126,8 @@ class TestAddChunk: | |||
| ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"), | |||
| ], | |||
| ) | |||
| def test_questions(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| def test_questions(self, get_http_api_auth, add_document, payload, expected_code, expected_message): | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| @@ -157,12 +157,12 @@ class TestAddChunk: | |||
| def test_invalid_dataset_id( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_id, | |||
| add_document, | |||
| dataset_id, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| _, document_id = get_dataset_id_and_document_id | |||
| _, document_id = add_document | |||
| res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -178,15 +178,15 @@ class TestAddChunk: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_id, document_id, expected_code, expected_message): | |||
| dataset_id, _ = get_dataset_id_and_document_id | |||
| def test_invalid_document_id(self, get_http_api_auth, add_document, document_id, expected_code, expected_message): | |||
| dataset_id, _ = add_document | |||
| res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| def test_repeated_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id): | |||
| def test_repeated_add_chunk(self, get_http_api_auth, add_document): | |||
| payload = {"content": "chunk test"} | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| @@ -207,17 +207,17 @@ class TestAddChunk: | |||
| assert False, res | |||
| assert res["data"]["doc"]["chunk_count"] == chunks_count + 2 | |||
| def test_add_chunk_to_deleted_document(self, get_http_api_auth, get_dataset_id_and_document_id): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]}) | |||
| def test_add_chunk_to_deleted_document(self, get_http_api_auth, add_document): | |||
| dataset_id, document_id = add_document | |||
| delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]}) | |||
| res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"}) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == f"You don't own the document {document_id}." | |||
| @pytest.mark.skip(reason="issues/6411") | |||
| def test_concurrent_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id): | |||
| def test_concurrent_add_chunk(self, get_http_api_auth, add_document): | |||
| chunk_num = 50 | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| @@ -39,7 +39,7 @@ class TestAuthorization: | |||
| assert res["message"] == expected_message | |||
| class TestChunkstDeletion: | |||
| class TestChunksDeletion: | |||
| @pytest.mark.parametrize( | |||
| "dataset_id, expected_code, expected_message", | |||
| [ | |||
| @@ -61,25 +61,14 @@ class TestChunkstDeletion: | |||
| "document_id, expected_code, expected_message", | |||
| [ | |||
| ("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"), | |||
| pytest.param( | |||
| "invalid_document_id", | |||
| 100, | |||
| "LookupError('Document not found which is supposed to be there')", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6611"), | |||
| ), | |||
| pytest.param( | |||
| "invalid_document_id", | |||
| 100, | |||
| "rm_chunk deleted chunks 0, expect 4", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="issues/6611"), | |||
| ), | |||
| ("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, get_http_api_auth, add_chunks_func, document_id, expected_code, expected_message): | |||
| dataset_id, _, chunk_ids = add_chunks_func | |||
| res = delete_chunks(get_http_api_auth, dataset_id, document_id, {"chunk_ids": chunk_ids}) | |||
| assert res["code"] == expected_code | |||
| #assert res["message"] == expected_message | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.parametrize( | |||
| "payload", | |||
| @@ -17,11 +17,7 @@ import os | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| batch_add_chunks, | |||
| list_chunks, | |||
| ) | |||
| from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -153,8 +149,9 @@ class TestChunksList: | |||
| assert all(r["code"] == 0 for r in responses) | |||
| assert all(len(r["data"]["chunks"]) == 5 for r in responses) | |||
| def test_default(self, get_http_api_auth, get_dataset_id_and_document_id): | |||
| dataset_id, document_id = get_dataset_id_and_document_id | |||
| def test_default(self, get_http_api_auth, add_document): | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(get_http_api_auth, dataset_id, document_id) | |||
| chunks_count = res["data"]["doc"]["chunk_count"] | |||
| batch_add_chunks(get_http_api_auth, dataset_id, document_id, 31) | |||
| @@ -13,7 +13,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| import pytest | |||
| @@ -52,9 +51,7 @@ class TestChunksRetrieval: | |||
| ({"question": "chunk"}, 102, 0, "`dataset_ids` is required."), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios( | |||
| self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message | |||
| ): | |||
| def test_basic_scenarios(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, document_id, _ = add_chunks | |||
| if "dataset_ids" in payload: | |||
| payload["dataset_ids"] = [dataset_id] | |||
| @@ -137,9 +134,7 @@ class TestChunksRetrieval: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_page_size( | |||
| self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message | |||
| ): | |||
| def test_page_size(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) | |||
| @@ -165,9 +160,7 @@ class TestChunksRetrieval: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_vector_similarity_weight( | |||
| self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message | |||
| ): | |||
| def test_vector_similarity_weight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) | |||
| res = retrieval_chunks(get_http_api_auth, payload) | |||
| @@ -233,9 +226,7 @@ class TestChunksRetrieval: | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""), | |||
| pytest.param( | |||
| {"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip | |||
| ), | |||
| pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message): | |||
| @@ -248,7 +239,6 @@ class TestChunksRetrieval: | |||
| else: | |||
| assert expected_message in res["message"] | |||
| @pytest.mark.skip(reason="chat model is not set") | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| @@ -279,9 +269,7 @@ class TestChunksRetrieval: | |||
| pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), | |||
| ], | |||
| ) | |||
| def test_highlight( | |||
| self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message | |||
| ): | |||
| def test_highlight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) | |||
| res = retrieval_chunks(get_http_api_auth, payload) | |||
| @@ -302,3 +290,14 @@ class TestChunksRetrieval: | |||
| res = retrieval_chunks(get_http_api_auth, payload) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]["chunks"]) == 4 | |||
| def test_concurrent_retrieval(self, get_http_api_auth, add_chunks): | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| dataset_id, _, _ = add_chunks | |||
| payload = {"question": "chunk", "dataset_ids": [dataset_id]} | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(retrieval_chunks, get_http_api_auth, payload) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
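`test_concurrent_retrieval` follows the same concurrency pattern used throughout the suite: fan the same call out over a small thread pool and require every response to succeed. Reduced to a self-contained sketch (`call_api` stands in for helpers such as `retrieval_chunks` or `list_datasets`):

```python
from concurrent.futures import ThreadPoolExecutor


def call_api(i):
    # Stand-in for an HTTP helper; a real call would hit the RAGFlow API.
    return {"code": 0, "data": f"response {i}"}


with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(call_api, i) for i in range(100)]
    responses = [f.result() for f in futures]

assert all(r["code"] == 0 for r in responses)
```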
| @@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor | |||
| from random import randint | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, delete_documnet, update_chunk | |||
| from common import INVALID_API_TOKEN, delete_documnets, update_chunk | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -233,7 +233,7 @@ class TestUpdatedChunk: | |||
| def test_update_chunk_to_deleted_document(self, get_http_api_auth, add_chunks): | |||
| dataset_id, document_id, chunk_ids = add_chunks | |||
| delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]}) | |||
| delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]}) | |||
| res = update_chunk(get_http_api_auth, dataset_id, document_id, chunk_ids[0]) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == f"Can't find this chunk {chunk_ids[0]}" | |||
| @@ -16,14 +16,24 @@ | |||
| import pytest | |||
| from common import batch_create_datasets, delete_dataset | |||
| from common import batch_create_datasets, delete_datasets | |||
| @pytest.fixture(scope="class") | |||
| def get_dataset_ids(get_http_api_auth, request): | |||
| def add_datasets(get_http_api_auth, request): | |||
| def cleanup(): | |||
| delete_dataset(get_http_api_auth) | |||
| delete_datasets(get_http_api_auth) | |||
| request.addfinalizer(cleanup) | |||
| return batch_create_datasets(get_http_api_auth, 5) | |||
| @pytest.fixture(scope="function") | |||
| def add_datasets_func(get_http_api_auth, request): | |||
| def cleanup(): | |||
| delete_datasets(get_http_api_auth) | |||
| request.addfinalizer(cleanup) | |||
| return batch_create_datasets(get_http_api_auth, 3) | |||
| @@ -75,9 +75,6 @@ class TestDatasetCreation: | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 0, f"Failed to create dataset {i}" | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestAdvancedConfigurations: | |||
| def test_avatar(self, get_http_api_auth, tmp_path): | |||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||
| payload = { | |||
| @@ -20,13 +20,12 @@ import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| delete_dataset, | |||
| list_dataset, | |||
| delete_datasets, | |||
| list_datasets, | |||
| ) | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "auth, expected_code, expected_message", | |||
| @@ -39,18 +38,13 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = delete_dataset(auth, {"ids": ids}) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = delete_datasets(auth) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| res = list_dataset(get_http_api_auth) | |||
| assert len(res["data"]) == 1 | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestDatasetDeletion: | |||
| class TestDatasetsDeletion: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message, remaining", | |||
| [ | |||
| @@ -73,16 +67,16 @@ class TestDatasetDeletion: | |||
| (lambda r: {"ids": r}, 0, "", 0), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining): | |||
| ids = batch_create_datasets(get_http_api_auth, 3) | |||
| def test_basic_scenarios(self, get_http_api_auth, add_datasets_func, payload, expected_code, expected_message, remaining): | |||
| dataset_ids = add_datasets_func | |||
| if callable(payload): | |||
| payload = payload(ids) | |||
| res = delete_dataset(get_http_api_auth, payload) | |||
| payload = payload(dataset_ids) | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == expected_code | |||
| if res["code"] != 0: | |||
| assert res["message"] == expected_message | |||
| res = list_dataset(get_http_api_auth) | |||
| res = list_datasets(get_http_api_auth) | |||
| assert len(res["data"]) == remaining | |||
| @pytest.mark.parametrize( | |||
| @@ -93,50 +87,50 @@ class TestDatasetDeletion: | |||
| lambda r: {"ids": r + ["invalid_id"]}, | |||
| ], | |||
| ) | |||
| def test_delete_partial_invalid_id(self, get_http_api_auth, payload): | |||
| ids = batch_create_datasets(get_http_api_auth, 3) | |||
| def test_delete_partial_invalid_id(self, get_http_api_auth, add_datasets_func, payload): | |||
| dataset_ids = add_datasets_func | |||
| if callable(payload): | |||
| payload = payload(ids) | |||
| res = delete_dataset(get_http_api_auth, payload) | |||
| payload = payload(dataset_ids) | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 0 | |||
| assert res["data"]["errors"][0] == "You don't own the dataset invalid_id" | |||
| assert res["data"]["success_count"] == 3 | |||
| res = list_dataset(get_http_api_auth) | |||
| res = list_datasets(get_http_api_auth) | |||
| assert len(res["data"]) == 0 | |||
| def test_repeated_deletion(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = delete_dataset(get_http_api_auth, {"ids": ids}) | |||
| def test_repeated_deletion(self, get_http_api_auth, add_datasets_func): | |||
| dataset_ids = add_datasets_func | |||
| res = delete_datasets(get_http_api_auth, {"ids": dataset_ids}) | |||
| assert res["code"] == 0 | |||
| res = delete_dataset(get_http_api_auth, {"ids": ids}) | |||
| res = delete_datasets(get_http_api_auth, {"ids": dataset_ids}) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == f"You don't own the dataset {ids[0]}" | |||
| assert "You don't own the dataset" in res["message"] | |||
| def test_duplicate_deletion(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = delete_dataset(get_http_api_auth, {"ids": ids + ids}) | |||
| def test_duplicate_deletion(self, get_http_api_auth, add_datasets_func): | |||
| dataset_ids = add_datasets_func | |||
| res = delete_datasets(get_http_api_auth, {"ids": dataset_ids + dataset_ids}) | |||
| assert res["code"] == 0 | |||
| assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}" | |||
| assert res["data"]["success_count"] == 1 | |||
| assert "Duplicate dataset ids" in res["data"]["errors"][0] | |||
| assert res["data"]["success_count"] == 3 | |||
| res = list_dataset(get_http_api_auth) | |||
| res = list_datasets(get_http_api_auth) | |||
| assert len(res["data"]) == 0 | |||
| def test_concurrent_deletion(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 100) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)] | |||
| futures = [executor.submit(delete_datasets, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| @pytest.mark.slow | |||
| def test_delete_10k(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 10_000) | |||
| res = delete_dataset(get_http_api_auth, {"ids": ids}) | |||
| res = delete_datasets(get_http_api_auth, {"ids": ids}) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth) | |||
| res = list_datasets(get_http_api_auth) | |||
| assert len(res["data"]) == 0 | |||
| @@ -16,7 +16,7 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, list_dataset | |||
| from common import INVALID_API_TOKEN, list_datasets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -25,7 +25,6 @@ def is_sorted(data, field, descending=True): | |||
| return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:])) | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "auth, expected_code, expected_message", | |||
| @@ -39,15 +38,15 @@ class TestAuthorization: | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = list_dataset(auth) | |||
| res = list_datasets(auth) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("get_dataset_ids") | |||
| class TestDatasetList: | |||
| @pytest.mark.usefixtures("add_datasets") | |||
| class TestDatasetsList: | |||
| def test_default(self, get_http_api_auth): | |||
| res = list_dataset(get_http_api_auth, params={}) | |||
| res = list_datasets(get_http_api_auth, params={}) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 5 | |||
| @@ -77,7 +76,7 @@ class TestDatasetList: | |||
| ], | |||
| ) | |||
| def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message): | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_page_size | |||
| @@ -116,7 +115,7 @@ class TestDatasetList: | |||
| expected_page_size, | |||
| expected_message, | |||
| ): | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_page_size | |||
| @@ -168,7 +167,7 @@ class TestDatasetList: | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| @@ -244,7 +243,7 @@ class TestDatasetList: | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| @@ -262,7 +261,7 @@ class TestDatasetList: | |||
| ], | |||
| ) | |||
| def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message): | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if params["name"] in [None, ""]: | |||
| @@ -284,19 +283,19 @@ class TestDatasetList: | |||
| def test_id( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_ids, | |||
| add_datasets, | |||
| dataset_id, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| dataset_ids = get_dataset_ids | |||
| dataset_ids = add_datasets | |||
| if callable(dataset_id): | |||
| params = {"id": dataset_id(dataset_ids)} | |||
| else: | |||
| params = {"id": dataset_id} | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if params["id"] in [None, ""]: | |||
| @@ -318,20 +317,20 @@ class TestDatasetList: | |||
| def test_name_and_id( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_ids, | |||
| add_datasets, | |||
| dataset_id, | |||
| name, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| dataset_ids = get_dataset_ids | |||
| dataset_ids = add_datasets | |||
| if callable(dataset_id): | |||
| params = {"id": dataset_id(dataset_ids), "name": name} | |||
| else: | |||
| params = {"id": dataset_id, "name": name} | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_num | |||
| else: | |||
| @@ -339,12 +338,12 @@ class TestDatasetList: | |||
| def test_concurrent_list(self, get_http_api_auth): | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)] | |||
| futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| def test_invalid_params(self, get_http_api_auth): | |||
| params = {"a": "b"} | |||
| res = list_dataset(get_http_api_auth, params=params) | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 5 | |||
| @@ -19,8 +19,7 @@ import pytest | |||
| from common import ( | |||
| DATASET_NAME_LIMIT, | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| list_dataset, | |||
| list_datasets, | |||
| update_dataset, | |||
| ) | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -30,7 +29,6 @@ from libs.utils.file_utils import create_image_file | |||
| # TODO: Missing scenario for updating embedding_model with chunk_count != 0 | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "auth, expected_code, expected_message", | |||
| @@ -43,14 +41,12 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(auth, ids[0], {"name": "new_name"}) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = update_dataset(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestDatasetUpdate: | |||
| @pytest.mark.parametrize( | |||
| "name, expected_code, expected_message", | |||
| @@ -72,12 +68,12 @@ class TestDatasetUpdate: | |||
| ("DATASET_1", 102, "Duplicated dataset name in updating dataset."), | |||
| ], | |||
| ) | |||
| def test_name(self, get_http_api_auth, name, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 2) | |||
| res = update_dataset(get_http_api_auth, ids[0], {"name": name}) | |||
| def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message): | |||
| dataset_ids = add_datasets_func | |||
| res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name}) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]}) | |||
| assert res["data"][0]["name"] == name | |||
| else: | |||
| assert res["message"] == expected_message | |||
| @@ -95,12 +91,12 @@ class TestDatasetUpdate: | |||
| (None, 102, "`embedding_model` can't be empty"), | |||
| ], | |||
| ) | |||
| def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model}) | |||
| def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message): | |||
| dataset_id = add_dataset_func | |||
| res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model}) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["embedding_model"] == embedding_model | |||
| else: | |||
| assert res["message"] == expected_message | |||
| @@ -129,12 +125,12 @@ class TestDatasetUpdate: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method}) | |||
| def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message): | |||
| dataset_id = add_dataset_func | |||
| res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method}) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| if chunk_method != "": | |||
| assert res["data"][0]["chunk_method"] == chunk_method | |||
| else: | |||
| @@ -142,38 +138,38 @@ class TestDatasetUpdate: | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_avatar(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||
| payload = {"avatar": encode_avatar(fn)} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| def test_description(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_description(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"description": "description"} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["description"] == "description" | |||
| def test_pagerank(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_pagerank(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"pagerank": 1} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["pagerank"] == 1 | |||
| def test_similarity_threshold(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_similarity_threshold(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"similarity_threshold": 1} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["similarity_threshold"] == 1 | |||
| @pytest.mark.parametrize( | |||
| @@ -187,29 +183,28 @@ class TestDatasetUpdate: | |||
| ("other_permission", 102), | |||
| ], | |||
| ) | |||
| def test_permission(self, get_http_api_auth, permission, expected_code): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code): | |||
| dataset_id = add_dataset_func | |||
| payload = {"permission": permission} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == expected_code | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| if expected_code == 0 and permission != "": | |||
| assert res["data"][0]["permission"] == permission | |||
| if permission == "": | |||
| assert res["data"][0]["permission"] == "me" | |||
| def test_vector_similarity_weight(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"vector_similarity_weight": 1} | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["vector_similarity_weight"] == 1 | |||
| def test_invalid_dataset_id(self, get_http_api_auth): | |||
| batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == "You don't own the dataset" | |||
| @@ -230,21 +225,21 @@ class TestDatasetUpdate: | |||
| {"update_time": 1741671443339}, | |||
| ], | |||
| ) | |||
| def test_modify_read_only_field(self, get_http_api_auth, payload): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(get_http_api_auth, ids[0], payload) | |||
| def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload): | |||
| dataset_id = add_dataset_func | |||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 101 | |||
| assert "is readonly" in res["message"] | |||
| def test_modify_unknown_field(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0}) | |||
| def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0}) | |||
| assert res["code"] == 100 | |||
| def test_concurrent_update(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_concurrent_update(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)] | |||
| futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| @@ -16,22 +16,36 @@ | |||
| import pytest | |||
| from common import batch_create_datasets, bulk_upload_documents, delete_dataset | |||
| from common import bulk_upload_documents, delete_documnets | |||
| @pytest.fixture(scope="class") | |||
| def file_management_tmp_dir(tmp_path_factory): | |||
| return tmp_path_factory.mktemp("file_management") | |||
| @pytest.fixture(scope="function") | |||
| def add_document_func(request, get_http_api_auth, add_dataset, ragflow_tmp_dir): | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir) | |||
| def cleanup(): | |||
| delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) | |||
| request.addfinalizer(cleanup) | |||
| return dataset_id, document_ids[0] | |||
| @pytest.fixture(scope="class") | |||
| def get_dataset_id_and_document_ids(get_http_api_auth, file_management_tmp_dir, request): | |||
| def add_documents(request, get_http_api_auth, add_dataset, ragflow_tmp_dir): | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, ragflow_tmp_dir) | |||
| def cleanup(): | |||
| delete_dataset(get_http_api_auth) | |||
| delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) | |||
| request.addfinalizer(cleanup) | |||
| return dataset_id, document_ids | |||
| @pytest.fixture(scope="function") | |||
| def add_documents_func(get_http_api_auth, add_dataset_func, ragflow_tmp_dir): | |||
| dataset_id = add_dataset_func | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, ragflow_tmp_dir) | |||
| dataset_ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = dataset_ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, file_management_tmp_dir) | |||
| return dataset_id, document_ids | |||
| @@ -16,13 +16,7 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| bulk_upload_documents, | |||
| delete_documnet, | |||
| list_documnet, | |||
| ) | |||
| from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documnets, list_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -38,15 +32,13 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| res = delete_documnet(auth, dataset_id, {"ids": document_ids}) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = delete_documnets(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestDocumentDeletion: | |||
| class TestDocumentsDeletion: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message, remaining", | |||
| [ | |||
| @@ -72,22 +64,21 @@ class TestDocumentDeletion: | |||
| def test_basic_scenarios( | |||
| self, | |||
| get_http_api_auth, | |||
| tmp_path, | |||
| add_documents_func, | |||
| payload, | |||
| expected_code, | |||
| expected_message, | |||
| remaining, | |||
| ): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) | |||
| dataset_id, document_ids = add_documents_func | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = delete_documnet(get_http_api_auth, ids[0], payload) | |||
| res = delete_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == expected_code | |||
| if res["code"] != 0: | |||
| assert res["message"] == expected_message | |||
| res = list_documnet(get_http_api_auth, ids[0]) | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert len(res["data"]["docs"]) == remaining | |||
| assert res["data"]["total"] == remaining | |||
| @@ -102,10 +93,9 @@ class TestDocumentDeletion: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) | |||
| res = delete_documnet(get_http_api_auth, dataset_id, {"ids": document_ids[:1]}) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, add_documents_func, dataset_id, expected_code, expected_message): | |||
| _, document_ids = add_documents_func | |||
| res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids[:1]}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -117,69 +107,68 @@ class TestDocumentDeletion: | |||
| lambda r: {"ids": r + ["invalid_id"]}, | |||
| ], | |||
| ) | |||
| def test_delete_partial_invalid_id(self, get_http_api_auth, tmp_path, payload): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) | |||
| def test_delete_partial_invalid_id(self, get_http_api_auth, add_documents_func, payload): | |||
| dataset_id, document_ids = add_documents_func | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = delete_documnet(get_http_api_auth, ids[0], payload) | |||
| res = delete_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == "Documents not found: ['invalid_id']" | |||
| res = list_documnet(get_http_api_auth, ids[0]) | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert len(res["data"]["docs"]) == 0 | |||
| assert res["data"]["total"] == 0 | |||
| def test_repeated_deletion(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) | |||
| res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids}) | |||
| def test_repeated_deletion(self, get_http_api_auth, add_documents_func): | |||
| dataset_id, document_ids = add_documents_func | |||
| res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids}) | |||
| res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == f"Documents not found: {document_ids}" | |||
| assert "Documents not found" in res["message"] | |||
| def test_duplicate_deletion(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) | |||
| res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids + document_ids}) | |||
| def test_duplicate_deletion(self, get_http_api_auth, add_documents_func): | |||
| dataset_id, document_ids = add_documents_func | |||
| res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids + document_ids}) | |||
| assert res["code"] == 0 | |||
| assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}" | |||
| assert res["data"]["success_count"] == 1 | |||
| assert "Duplicate document ids" in res["data"]["errors"][0] | |||
| assert res["data"]["success_count"] == 3 | |||
| res = list_documnet(get_http_api_auth, ids[0]) | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert len(res["data"]["docs"]) == 0 | |||
| assert res["data"]["total"] == 0 | |||
| def test_concurrent_deletion(self, get_http_api_auth, tmp_path): | |||
| documnets_num = 100 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| delete_documnet, | |||
| get_http_api_auth, | |||
| ids[0], | |||
| {"ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(documnets_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| @pytest.mark.slow | |||
| def test_delete_1k(self, get_http_api_auth, tmp_path): | |||
| documnets_num = 1_000 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path) | |||
| res = list_documnet(get_http_api_auth, ids[0]) | |||
| assert res["data"]["total"] == documnets_num | |||
| res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| res = list_documnet(get_http_api_auth, ids[0]) | |||
| assert res["data"]["total"] == 0 | |||
| def test_concurrent_deletion(get_http_api_auth, add_dataset, tmp_path): | |||
| documnets_num = 100 | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| delete_documnets, | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| {"ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(documnets_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| @pytest.mark.slow | |||
| def test_delete_1k(get_http_api_auth, add_dataset, tmp_path): | |||
| documnets_num = 1_000 | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path) | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert res["data"]["total"] == documnets_num | |||
| res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert res["data"]["total"] == 0 | |||
| @@ -18,7 +18,7 @@ import json | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, download_document, upload_documnets | |||
| from common import INVALID_API_TOKEN, bulk_upload_documents, download_document, upload_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| from libs.utils import compare_by_hash | |||
| from requests import codes | |||
| @@ -36,9 +36,8 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_dataset_id_and_document_ids, tmp_path, auth, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| res = download_document(auth, dataset_id, document_ids[0], tmp_path / "ragflow_tes.txt") | |||
| def test_invalid_auth(self, tmp_path, auth, expected_code, expected_message): | |||
| res = download_document(auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt") | |||
| assert res.status_code == codes.ok | |||
| with (tmp_path / "ragflow_tes.txt").open("r") as f: | |||
| response_json = json.load(f) | |||
| @@ -46,7 +45,6 @@ class TestAuthorization: | |||
| assert response_json["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| @pytest.mark.parametrize( | |||
| "generate_test_files", | |||
| [ | |||
| @@ -63,15 +61,15 @@ class TestAuthorization: | |||
| ], | |||
| indirect=True, | |||
| ) | |||
| def test_file_type_validation(get_http_api_auth, generate_test_files, request): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_file_type_validation(get_http_api_auth, add_dataset, generate_test_files, request): | |||
| dataset_id = add_dataset | |||
| fp = generate_test_files[request.node.callspec.params["generate_test_files"]] | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| document_id = res["data"][0]["id"] | |||
| res = download_document( | |||
| get_http_api_auth, | |||
| ids[0], | |||
| dataset_id, | |||
| document_id, | |||
| fp.with_stem("ragflow_test_download"), | |||
| ) | |||
| @@ -93,8 +91,8 @@ class TestDocumentDownload: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, document_id, expected_code, expected_message): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| def test_invalid_document_id(self, get_http_api_auth, add_documents, tmp_path, document_id, expected_code, expected_message): | |||
| dataset_id, _ = add_documents | |||
| res = download_document( | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| @@ -118,8 +116,8 @@ class TestDocumentDownload: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, dataset_id, expected_code, expected_message): | |||
| _, document_ids = get_dataset_id_and_document_ids | |||
| def test_invalid_dataset_id(self, get_http_api_auth, add_documents, tmp_path, dataset_id, expected_code, expected_message): | |||
| _, document_ids = add_documents | |||
| res = download_document( | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| @@ -132,9 +130,9 @@ class TestDocumentDownload: | |||
| assert response_json["code"] == expected_code | |||
| assert response_json["message"] == expected_message | |||
| def test_same_file_repeat(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, file_management_tmp_dir): | |||
| def test_same_file_repeat(self, get_http_api_auth, add_documents, tmp_path, ragflow_tmp_dir): | |||
| num = 5 | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| dataset_id, document_ids = add_documents | |||
| for i in range(num): | |||
| res = download_document( | |||
| get_http_api_auth, | |||
| @@ -144,23 +142,22 @@ class TestDocumentDownload: | |||
| ) | |||
| assert res.status_code == codes.ok | |||
| assert compare_by_hash( | |||
| file_management_tmp_dir / "ragflow_test_upload_0.txt", | |||
| ragflow_tmp_dir / "ragflow_test_upload_0.txt", | |||
| tmp_path / f"ragflow_test_download_{i}.txt", | |||
| ) | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| def test_concurrent_download(get_http_api_auth, tmp_path): | |||
| def test_concurrent_download(get_http_api_auth, add_dataset, tmp_path): | |||
| document_count = 20 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], document_count, tmp_path) | |||
| dataset_id = add_dataset | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_count, tmp_path) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| download_document, | |||
| get_http_api_auth, | |||
| ids[0], | |||
| dataset_id, | |||
| document_ids[i], | |||
| tmp_path / f"ragflow_test_download_{i}.txt", | |||
| ) | |||
| @@ -16,10 +16,7 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| list_documnet, | |||
| ) | |||
| from common import INVALID_API_TOKEN, list_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -40,17 +37,16 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(auth, dataset_id) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = list_documnets(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| class TestDocumentList: | |||
| def test_default(self, get_http_api_auth, get_dataset_id_and_document_ids): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id) | |||
| class TestDocumentsList: | |||
| def test_default(self, get_http_api_auth, add_documents): | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]["docs"]) == 5 | |||
| assert res["data"]["total"] == 5 | |||
| @@ -66,8 +62,8 @@ class TestDocumentList: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message): | |||
| res = list_documnet(get_http_api_auth, dataset_id) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, dataset_id, expected_code, expected_message): | |||
| res = list_documnets(get_http_api_auth, dataset_id) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -98,14 +94,14 @@ class TestDocumentList: | |||
| def test_page( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| params, | |||
| expected_code, | |||
| expected_page_size, | |||
| expected_message, | |||
| ): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]["docs"]) == expected_page_size | |||
| @@ -140,14 +136,14 @@ class TestDocumentList: | |||
| def test_page_size( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| params, | |||
| expected_code, | |||
| expected_page_size, | |||
| expected_message, | |||
| ): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]["docs"]) == expected_page_size | |||
| @@ -194,14 +190,14 @@ class TestDocumentList: | |||
| def test_orderby( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| params, | |||
| expected_code, | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| @@ -273,14 +269,14 @@ class TestDocumentList: | |||
| def test_desc( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| params, | |||
| expected_code, | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| @@ -298,9 +294,9 @@ class TestDocumentList: | |||
| ({"keywords": "unknown"}, 0), | |||
| ], | |||
| ) | |||
| def test_keywords(self, get_http_api_auth, get_dataset_id_and_document_ids, params, expected_num): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| def test_keywords(self, get_http_api_auth, add_documents, params, expected_num): | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]["docs"]) == expected_num | |||
| assert res["data"]["total"] == expected_num | |||
| @@ -322,14 +318,14 @@ class TestDocumentList: | |||
| def test_name( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| params, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| dataset_id, _ = add_documents | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if params["name"] in [None, ""]: | |||
| @@ -351,18 +347,18 @@ class TestDocumentList: | |||
| def test_id( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| document_id, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| dataset_id, document_ids = add_documents | |||
| if callable(document_id): | |||
| params = {"id": document_id(document_ids)} | |||
| else: | |||
| params = {"id": document_id} | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| @@ -391,36 +387,36 @@ class TestDocumentList: | |||
| def test_name_and_id( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| document_id, | |||
| name, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| dataset_id, document_ids = add_documents | |||
| if callable(document_id): | |||
| params = {"id": document_id(document_ids), "name": name} | |||
| else: | |||
| params = {"id": document_id, "name": name} | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| if expected_code == 0: | |||
| assert len(res["data"]["docs"]) == expected_num | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_concurrent_list(self, get_http_api_auth, get_dataset_id_and_document_ids): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| def test_concurrent_list(self, get_http_api_auth, add_documents): | |||
| dataset_id, _ = add_documents | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_documnet, get_http_api_auth, dataset_id) for i in range(100)] | |||
| futures = [executor.submit(list_documnets, get_http_api_auth, dataset_id) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| def test_invalid_params(self, get_http_api_auth, get_dataset_id_and_document_ids): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| def test_invalid_params(self, get_http_api_auth, add_documents): | |||
| dataset_id, _ = add_documents | |||
| params = {"a": "b"} | |||
| res = list_documnet(get_http_api_auth, dataset_id, params=params) | |||
| res = list_documnets(get_http_api_auth, dataset_id, params=params) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]["docs"]) == 5 | |||
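
The listing assertions above all go through `list_documnets`, a thin wrapper over a GET request with query parameters. A hedged sketch of an equivalent helper — the endpoint path is an assumption based on RAGFlow's published HTTP API, not something this diff shows:

```python
# Equivalent of common.list_documnets — endpoint path assumed, not confirmed here
import requests
from common import HOST_ADDRESS


def list_documnets(auth, dataset_id, params=None):
    url = f"{HOST_ADDRESS}/api/v1/datasets/{dataset_id}/documents"
    # auth is a RAGFlowHttpApiAuth instance, so requests attaches the
    # Authorization header automatically.
    response = requests.get(url, auth=auth, params=params)
    return response.json()
```
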
| @@ -16,20 +16,14 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| bulk_upload_documents, | |||
| list_documnet, | |||
| parse_documnet, | |||
| ) | |||
| from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| from libs.utils import wait_for | |||
| def validate_document_details(auth, dataset_id, document_ids): | |||
| for document_id in document_ids: | |||
| res = list_documnet(auth, dataset_id, params={"id": document_id}) | |||
| res = list_documnets(auth, dataset_id, params={"id": document_id}) | |||
| doc = res["data"]["docs"][0] | |||
| assert doc["run"] == "DONE" | |||
| assert len(doc["process_begin_at"]) > 0 | |||
| @@ -50,14 +44,12 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| res = parse_documnet(auth, dataset_id, {"document_ids": document_ids}) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = parse_documnets(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestDocumentsParse: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| @@ -89,21 +81,19 @@ class TestDocumentsParse: | |||
| (lambda r: {"document_ids": r}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message): | |||
| def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message): | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_ids): | |||
| for _document_id in _document_ids: | |||
| res = list_documnet(_auth, _dataset_id, {"id": _document_id}) | |||
| res = list_documnets(_auth, _dataset_id, {"id": _document_id}) | |||
| if res["data"]["docs"][0]["run"] != "DONE": | |||
| return False | |||
| return True | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) | |||
| dataset_id, document_ids = add_documents_func | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, payload) | |||
| res = parse_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == expected_code | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message | |||
| @@ -125,14 +115,13 @@ class TestDocumentsParse: | |||
| def test_invalid_dataset_id( | |||
| self, | |||
| get_http_api_auth, | |||
| tmp_path, | |||
| add_documents_func, | |||
| dataset_id, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| _, document_ids = add_documents_func | |||
| res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -144,21 +133,19 @@ class TestDocumentsParse: | |||
| lambda r: {"document_ids": r + ["invalid_id"]}, | |||
| ], | |||
| ) | |||
| def test_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload): | |||
| def test_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload): | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id): | |||
| res = list_documnet(_auth, _dataset_id) | |||
| res = list_documnets(_auth, _dataset_id) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) | |||
| dataset_id, document_ids = add_documents_func | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, payload) | |||
| res = parse_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == "Documents not found: ['invalid_id']" | |||
| @@ -166,96 +153,92 @@ class TestDocumentsParse: | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
| def test_repeated_parse(self, get_http_api_auth, tmp_path): | |||
| def test_repeated_parse(self, get_http_api_auth, add_documents_func): | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id): | |||
| res = list_documnet(_auth, _dataset_id) | |||
| res = list_documnets(_auth, _dataset_id) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| dataset_id, document_ids = add_documents_func | |||
| res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| condition(get_http_api_auth, dataset_id) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| def test_duplicate_parse(self, get_http_api_auth, tmp_path): | |||
| def test_duplicate_parse(self, get_http_api_auth, add_documents_func): | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id): | |||
| res = list_documnet(_auth, _dataset_id) | |||
| res = list_documnets(_auth, _dataset_id) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) | |||
| dataset_id, document_ids = add_documents_func | |||
| res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) | |||
| assert res["code"] == 0 | |||
| assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}" | |||
| assert res["data"]["success_count"] == 1 | |||
| assert "Duplicate document ids" in res["data"]["errors"][0] | |||
| assert res["data"]["success_count"] == 3 | |||
| condition(get_http_api_auth, dataset_id) | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
| def test_parse_100_files(self, get_http_api_auth, tmp_path): | |||
| @wait_for(100, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_num): | |||
| res = list_documnet(_auth, _dataset_id, {"page_size": _document_num}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| document_num = 100 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| condition(get_http_api_auth, dataset_id, document_num) | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
| def test_concurrent_parse(self, get_http_api_auth, tmp_path): | |||
| @wait_for(120, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_num): | |||
| res = list_documnet(_auth, _dataset_id, {"page_size": _document_num}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| document_num = 100 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| parse_documnet, | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| {"document_ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(document_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| condition(get_http_api_auth, dataset_id, document_num) | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
def test_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path): | |||
| @wait_for(100, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_num): | |||
| res = list_documnets(_auth, _dataset_id, {"page_size": _document_num}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| document_num = 100 | |||
dataset_id = add_dataset_func | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| condition(get_http_api_auth, dataset_id, document_num) | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
def test_concurrent_parse(get_http_api_auth, add_dataset_func, tmp_path): | |||
| @wait_for(120, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_num): | |||
| res = list_documnets(_auth, _dataset_id, {"page_size": _document_num}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "DONE": | |||
| return False | |||
| return True | |||
| document_num = 100 | |||
dataset_id = add_dataset_func | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| parse_documnets, | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| {"document_ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(document_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| condition(get_http_api_auth, dataset_id, document_num) | |||
| validate_document_details(get_http_api_auth, dataset_id, document_ids) | |||
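
The parse tests poll document status through a `@wait_for(timeout, interval, message)` decorator from `libs.utils`: the wrapped `condition` predicate is retried until it returns True or the timeout expires. The decorator itself is not shown in the diff; a minimal sketch of the assumed behaviour:

```python
# libs/utils sketch — assumed behaviour of wait_for, for illustration only
import functools
import time


def wait_for(timeout: float, interval: float, error_message: str):
    """Retry the wrapped predicate until it returns truthy or the timeout expires."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if func(*args, **kwargs):
                    return True
                time.sleep(interval)
            raise TimeoutError(error_message)
        return wrapper
    return decorator
```
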
| @@ -16,21 +16,14 @@ | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| from common import ( | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| bulk_upload_documents, | |||
| list_documnet, | |||
| parse_documnet, | |||
| stop_parse_documnet, | |||
| ) | |||
| from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| from libs.utils import wait_for | |||
| def validate_document_parse_done(auth, dataset_id, document_ids): | |||
| for document_id in document_ids: | |||
| res = list_documnet(auth, dataset_id, params={"id": document_id}) | |||
| res = list_documnets(auth, dataset_id, params={"id": document_id}) | |||
| doc = res["data"]["docs"][0] | |||
| assert doc["run"] == "DONE" | |||
| assert len(doc["process_begin_at"]) > 0 | |||
| @@ -41,14 +34,13 @@ def validate_document_parse_done(auth, dataset_id, document_ids): | |||
| def validate_document_parse_cancel(auth, dataset_id, document_ids): | |||
| for document_id in document_ids: | |||
| res = list_documnet(auth, dataset_id, params={"id": document_id}) | |||
| res = list_documnets(auth, dataset_id, params={"id": document_id}) | |||
| doc = res["data"]["docs"][0] | |||
| assert doc["run"] == "CANCEL" | |||
| assert len(doc["process_begin_at"]) > 0 | |||
| assert doc["progress"] == 0.0 | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "auth, expected_code, expected_message", | |||
| @@ -61,15 +53,13 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = stop_parse_documnet(auth, ids[0]) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = stop_parse_documnets(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.skip | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestDocumentsParseStop: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| @@ -101,24 +91,22 @@ class TestDocumentsParseStop: | |||
| (lambda r: {"document_ids": r}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message): | |||
| def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message): | |||
| @wait_for(10, 1, "Document parsing timeout") | |||
| def condition(_auth, _dataset_id, _document_ids): | |||
| for _document_id in _document_ids: | |||
| res = list_documnet(_auth, _dataset_id, {"id": _document_id}) | |||
| res = list_documnets(_auth, _dataset_id, {"id": _document_id}) | |||
| if res["data"]["docs"][0]["run"] != "DONE": | |||
| return False | |||
| return True | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| dataset_id, document_ids = add_documents_func | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, payload) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == expected_code | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message | |||
| @@ -129,7 +117,7 @@ class TestDocumentsParseStop: | |||
| validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids) | |||
| @pytest.mark.parametrize( | |||
| "dataset_id, expected_code, expected_message", | |||
| "invalid_dataset_id, expected_code, expected_message", | |||
| [ | |||
| ("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"), | |||
| ( | |||
| @@ -142,14 +130,14 @@ class TestDocumentsParseStop: | |||
| def test_invalid_dataset_id( | |||
| self, | |||
| get_http_api_auth, | |||
| tmp_path, | |||
| dataset_id, | |||
| add_documents_func, | |||
| invalid_dataset_id, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| dataset_id, document_ids = add_documents_func | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnets(get_http_api_auth, invalid_dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -162,71 +150,65 @@ class TestDocumentsParseStop: | |||
| lambda r: {"document_ids": r + ["invalid_id"]}, | |||
| ], | |||
| ) | |||
| def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload): | |||
| dataset_id, document_ids = add_documents_func | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| if callable(payload): | |||
| payload = payload(document_ids) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, payload) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, payload) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == "You don't own the document invalid_id." | |||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | |||
| def test_repeated_stop_parse(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| def test_repeated_stop_parse(self, get_http_api_auth, add_documents_func): | |||
| dataset_id, document_ids = add_documents_func | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 102 | |||
| assert res["message"] == "Can't stop parsing document with progress at 0 or 1" | |||
| def test_duplicate_stop_parse(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) | |||
| def test_duplicate_stop_parse(self, get_http_api_auth, add_documents_func): | |||
| dataset_id, document_ids = add_documents_func | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) | |||
| assert res["code"] == 0 | |||
| assert res["data"]["success_count"] == 1 | |||
| assert res["data"]["success_count"] == 3 | |||
| assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"] | |||
| @pytest.mark.slow | |||
| def test_stop_parse_100_files(self, get_http_api_auth, tmp_path): | |||
| document_num = 100 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
| def test_concurrent_parse(self, get_http_api_auth, tmp_path): | |||
| document_num = 50 | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| dataset_id = ids[0] | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| stop_parse_documnet, | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| {"document_ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(document_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
def test_stop_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path): | |||
| document_num = 100 | |||
dataset_id = add_dataset_func | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| assert res["code"] == 0 | |||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | |||
| @pytest.mark.slow | |||
def test_concurrent_parse(get_http_api_auth, add_dataset_func, tmp_path): | |||
| document_num = 50 | |||
dataset_id = add_dataset_func | |||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | |||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| stop_parse_documnets, | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| {"document_ids": document_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(document_num) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | |||
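
`parse_documnets` and `stop_parse_documnets` are the two helpers these stop-parse tests exercise. They are not defined in this diff; the sketch below assumes the documented RAGFlow HTTP API shape, where parsing maps to the chunks endpoint and stopping it to its DELETE counterpart — both paths are assumptions, not confirmed by the PR:

```python
# Equivalents of common.parse_documnets / stop_parse_documnets — paths assumed
import requests
from common import HOST_ADDRESS


def parse_documnets(auth, dataset_id, payload=None):
    # Parsing is assumed to be "create chunks" for the selected documents.
    url = f"{HOST_ADDRESS}/api/v1/datasets/{dataset_id}/chunks"
    return requests.post(url, auth=auth, json=payload).json()


def stop_parse_documnets(auth, dataset_id, payload=None):
    # Stopping is assumed to be the DELETE counterpart of the same endpoint.
    url = f"{HOST_ADDRESS}/api/v1/datasets/{dataset_id}/chunks"
    return requests.delete(url, auth=auth, json=payload).json()
```
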
| @@ -16,7 +16,7 @@ | |||
| import pytest | |||
| from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, list_documnet, update_documnet | |||
| from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, list_documnets, update_documnet | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -32,14 +32,13 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| res = update_documnet(auth, dataset_id, document_ids[0], {"name": "auth_test.txt"}) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = update_documnet(auth, "dataset_id", "document_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| class TestUpdatedDocument: | |||
| class TestDocumentsUpdated: | |||
| @pytest.mark.parametrize( | |||
| "name, expected_code, expected_message", | |||
| [ | |||
| @@ -81,12 +80,12 @@ class TestUpdatedDocument: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_name(self, get_http_api_auth, get_dataset_id_and_document_ids, name, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| def test_name(self, get_http_api_auth, add_documents, name, expected_code, expected_message): | |||
| dataset_id, document_ids = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": name}) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| assert res["data"]["docs"][0]["name"] == name | |||
| else: | |||
| assert res["message"] == expected_message | |||
| @@ -102,8 +101,8 @@ class TestUpdatedDocument: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, document_id, expected_code, expected_message): | |||
| dataset_id, _ = get_dataset_id_and_document_ids | |||
| def test_invalid_document_id(self, get_http_api_auth, add_documents, document_id, expected_code, expected_message): | |||
| dataset_id, _ = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_id, {"name": "new_name.txt"}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -119,8 +118,8 @@ class TestUpdatedDocument: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message): | |||
| _, document_ids = get_dataset_id_and_document_ids | |||
| def test_invalid_dataset_id(self, get_http_api_auth, add_documents, dataset_id, expected_code, expected_message): | |||
| _, document_ids = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": "new_name.txt"}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @@ -129,11 +128,11 @@ class TestUpdatedDocument: | |||
| "meta_fields, expected_code, expected_message", | |||
| [({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")], | |||
| ) | |||
| def test_meta_fields(self, get_http_api_auth, get_dataset_id_and_document_ids, meta_fields, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| def test_meta_fields(self, get_http_api_auth, add_documents, meta_fields, expected_code, expected_message): | |||
| dataset_id, document_ids = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"meta_fields": meta_fields}) | |||
| if expected_code == 0: | |||
| res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| assert res["data"]["docs"][0]["meta_fields"] == meta_fields | |||
| else: | |||
| assert res["message"] == expected_message | |||
| @@ -162,12 +161,12 @@ class TestUpdatedDocument: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_chunk_method(self, get_http_api_auth, get_dataset_id_and_document_ids, chunk_method, expected_code, expected_message): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| def test_chunk_method(self, get_http_api_auth, add_documents, chunk_method, expected_code, expected_message): | |||
| dataset_id, document_ids = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"chunk_method": chunk_method}) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| if chunk_method != "": | |||
| assert res["data"]["docs"][0]["chunk_method"] == chunk_method | |||
| else: | |||
| @@ -282,259 +281,259 @@ class TestUpdatedDocument: | |||
| def test_invalid_field( | |||
| self, | |||
| get_http_api_auth, | |||
| get_dataset_id_and_document_ids, | |||
| add_documents, | |||
| payload, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| dataset_id, document_ids = get_dataset_id_and_document_ids | |||
| dataset_id, document_ids = add_documents | |||
| res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], payload) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| @pytest.mark.parametrize( | |||
| "chunk_method, parser_config, expected_code, expected_message", | |||
| [ | |||
| ("naive", {}, 0, ""), | |||
| ( | |||
| "naive", | |||
| { | |||
| "chunk_token_num": 128, | |||
| "layout_recognize": "DeepDOC", | |||
| "html4excel": False, | |||
| "delimiter": "\\n!?;。;!?", | |||
| "task_page_size": 12, | |||
| "raptor": {"use_raptor": False}, | |||
| }, | |||
| 0, | |||
| "", | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": -1}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 0}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 100000000}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 3.14}, | |||
| 102, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ( | |||
| "naive", | |||
| {"layout_recognize": "DeepDOC"}, | |||
| 0, | |||
| "", | |||
| ), | |||
| ( | |||
| "naive", | |||
| {"layout_recognize": "Naive"}, | |||
| 0, | |||
| "", | |||
| ), | |||
| ("naive", {"html4excel": True}, 0, ""), | |||
| ("naive", {"html4excel": False}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"html4excel": 1}, | |||
| 100, | |||
| "AssertionError('html4excel should be True or False')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ("naive", {"delimiter": ""}, 0, ""), | |||
| ("naive", {"delimiter": "`##`"}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"delimiter": 1}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": -1}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 0}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 100000000}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ("naive", {"raptor": {"use_raptor": True}}, 0, ""), | |||
| ("naive", {"raptor": {"use_raptor": False}}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"invalid_key": "invalid_value"}, | |||
| 100, | |||
| """AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": -1}, | |||
| 100, | |||
| "AssertionError('auto_keywords should be in range from 0 to 32')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": 32}, | |||
| 100, | |||
| "AssertionError('auto_keywords should be in range from 0 to 32')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": -1}, | |||
| 100, | |||
| "AssertionError('auto_questions should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 10}, | |||
| 100, | |||
| "AssertionError('auto_questions should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": -1}, | |||
| 100, | |||
| "AssertionError('topn_tags should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": 10}, | |||
| 100, | |||
| "AssertionError('topn_tags should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ], | |||
| ) | |||
| def test_parser_config( | |||
| get_http_api_auth, | |||
| tmp_path, | |||
| chunk_method, | |||
| parser_config, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) | |||
| res = update_documnet( | |||
| get_http_api_auth, | |||
| ids[0], | |||
| document_ids[0], | |||
| {"chunk_method": chunk_method, "parser_config": parser_config}, | |||
| class TestUpdateDocumentParserConfig: | |||
| @pytest.mark.parametrize( | |||
| "chunk_method, parser_config, expected_code, expected_message", | |||
| [ | |||
| ("naive", {}, 0, ""), | |||
| ( | |||
| "naive", | |||
| { | |||
| "chunk_token_num": 128, | |||
| "layout_recognize": "DeepDOC", | |||
| "html4excel": False, | |||
| "delimiter": "\\n!?;。;!?", | |||
| "task_page_size": 12, | |||
| "raptor": {"use_raptor": False}, | |||
| }, | |||
| 0, | |||
| "", | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": -1}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 0}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 100000000}, | |||
| 100, | |||
| "AssertionError('chunk_token_num should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": 3.14}, | |||
| 102, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"chunk_token_num": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ( | |||
| "naive", | |||
| {"layout_recognize": "DeepDOC"}, | |||
| 0, | |||
| "", | |||
| ), | |||
| ( | |||
| "naive", | |||
| {"layout_recognize": "Naive"}, | |||
| 0, | |||
| "", | |||
| ), | |||
| ("naive", {"html4excel": True}, 0, ""), | |||
| ("naive", {"html4excel": False}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"html4excel": 1}, | |||
| 100, | |||
| "AssertionError('html4excel should be True or False')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ("naive", {"delimiter": ""}, 0, ""), | |||
| ("naive", {"delimiter": "`##`"}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"delimiter": 1}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": -1}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 0}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 100000000}, | |||
| 100, | |||
| "AssertionError('task_page_size should be in range from 1 to 100000000')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"task_page_size": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ("naive", {"raptor": {"use_raptor": True}}, 0, ""), | |||
| ("naive", {"raptor": {"use_raptor": False}}, 0, ""), | |||
| pytest.param( | |||
| "naive", | |||
| {"invalid_key": "invalid_value"}, | |||
| 100, | |||
| """AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": -1}, | |||
| 100, | |||
| "AssertionError('auto_keywords should be in range from 0 to 32')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": 32}, | |||
| 100, | |||
| "AssertionError('auto_keywords should be in range from 0 to 32')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_keywords": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": -1}, | |||
| 100, | |||
| "AssertionError('auto_questions should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 10}, | |||
| 100, | |||
| "AssertionError('auto_questions should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"auto_questions": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": -1}, | |||
| 100, | |||
| "AssertionError('topn_tags should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": 10}, | |||
| 100, | |||
| "AssertionError('topn_tags should be in range from 0 to 10')", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": 3.14}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| pytest.param( | |||
| "naive", | |||
| {"topn_tags": "1024"}, | |||
| 100, | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/6098"), | |||
| ), | |||
| ], | |||
| ) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_documnet(get_http_api_auth, ids[0], {"id": document_ids[0]}) | |||
| if parser_config != {}: | |||
| for k, v in parser_config.items(): | |||
| assert res["data"]["docs"][0]["parser_config"][k] == v | |||
| else: | |||
| assert res["data"]["docs"][0]["parser_config"] == { | |||
| "chunk_token_num": 128, | |||
| "delimiter": "\\n!?;。;!?", | |||
| "html4excel": False, | |||
| "layout_recognize": "DeepDOC", | |||
| "raptor": {"use_raptor": False}, | |||
| } | |||
| if expected_code != 0 or expected_message: | |||
| assert res["message"] == expected_message | |||
| def test_parser_config( | |||
| self, | |||
| get_http_api_auth, | |||
| add_documents, | |||
| chunk_method, | |||
| parser_config, | |||
| expected_code, | |||
| expected_message, | |||
| ): | |||
| dataset_id, document_ids = add_documents | |||
| res = update_documnet( | |||
| get_http_api_auth, | |||
| dataset_id, | |||
| document_ids[0], | |||
| {"chunk_method": chunk_method, "parser_config": parser_config}, | |||
| ) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) | |||
| if parser_config != {}: | |||
| for k, v in parser_config.items(): | |||
| assert res["data"]["docs"][0]["parser_config"][k] == v | |||
| else: | |||
| assert res["data"]["docs"][0]["parser_config"] == { | |||
| "chunk_token_num": 128, | |||
| "delimiter": "\\n!?;。;!?", | |||
| "html4excel": False, | |||
| "layout_recognize": "DeepDOC", | |||
| "raptor": {"use_raptor": False}, | |||
| } | |||
| if expected_code != 0 or expected_message: | |||
| assert res["message"] == expected_message | |||
| @@ -19,15 +19,7 @@ from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| import requests | |||
| from common import ( | |||
| DOCUMENT_NAME_LIMIT, | |||
| FILE_API_URL, | |||
| HOST_ADDRESS, | |||
| INVALID_API_TOKEN, | |||
| batch_create_datasets, | |||
| list_dataset, | |||
| upload_documnets, | |||
| ) | |||
| from common import DOCUMENT_NAME_LIMIT, FILE_API_URL, HOST_ADDRESS, INVALID_API_TOKEN, list_datasets, upload_documnets | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| from libs.utils.file_utils import create_txt_file | |||
| from requests_toolbelt import MultipartEncoder | |||
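The rewritten upload tests below take their dataset from an `add_dataset_func` fixture instead of calling `batch_create_datasets` inline. The fixture itself is defined elsewhere in the suite's conftest; a minimal sketch of the shape the call sites assume (one fresh dataset per test, returning its id) — the body is an assumption inferred from usage, not the real fixture:

```python
# Hypothetical conftest.py sketch; the real fixture (and its cleanup) lives
# outside this diff.
import pytest
from common import batch_create_datasets  # assumed to remain in the shared helpers


@pytest.fixture
def add_dataset_func(get_http_api_auth):
    # Function-scoped: each test gets its own dataset id.
    return batch_create_datasets(get_http_api_auth, 1)[0]
```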
| @@ -46,21 +38,19 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = upload_documnets(auth, ids[0]) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| res = upload_documnets(auth, "dataset_id") | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| class TestUploadDocuments: | |||
| def test_valid_single_upload(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| class TestDocumentsUpload: | |||
| def test_valid_single_upload(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fp = create_txt_file(tmp_path / "ragflow_test.txt") | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 0 | |||
| assert res["data"][0]["dataset_id"] == ids[0] | |||
| assert res["data"][0]["dataset_id"] == dataset_id | |||
| assert res["data"][0]["name"] == fp.name | |||
| @pytest.mark.parametrize( | |||
| @@ -79,45 +69,45 @@ class TestUploadDocuments: | |||
| ], | |||
| indirect=True, | |||
| ) | |||
| def test_file_type_validation(self, get_http_api_auth, generate_test_files, request): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_file_type_validation(self, get_http_api_auth, add_dataset_func, generate_test_files, request): | |||
| dataset_id = add_dataset_func | |||
| fp = generate_test_files[request.node.callspec.params["generate_test_files"]] | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 0 | |||
| assert res["data"][0]["dataset_id"] == ids[0] | |||
| assert res["data"][0]["dataset_id"] == dataset_id | |||
| assert res["data"][0]["name"] == fp.name | |||
| @pytest.mark.parametrize( | |||
| "file_type", | |||
| ["exe", "unknown"], | |||
| ) | |||
| def test_unsupported_file_type(self, get_http_api_auth, tmp_path, file_type): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_unsupported_file_type(self, get_http_api_auth, add_dataset_func, tmp_path, file_type): | |||
| dataset_id = add_dataset_func | |||
| fp = tmp_path / f"ragflow_test.{file_type}" | |||
| fp.touch() | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 500 | |||
| assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!" | |||
| def test_missing_file(self, get_http_api_auth): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| res = upload_documnets(get_http_api_auth, ids[0]) | |||
| def test_missing_file(self, get_http_api_auth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| res = upload_documnets(get_http_api_auth, dataset_id) | |||
| assert res["code"] == 101 | |||
| assert res["message"] == "No file part!" | |||
| def test_empty_file(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_empty_file(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fp = tmp_path / "empty.txt" | |||
| fp.touch() | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 0 | |||
| assert res["data"][0]["size"] == 0 | |||
| def test_filename_empty(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_filename_empty(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fp = create_txt_file(tmp_path / "ragflow_test.txt") | |||
| url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=ids[0]) | |||
| url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) | |||
| fields = (("file", ("", fp.open("rb"))),) | |||
| m = MultipartEncoder(fields=fields) | |||
| res = requests.post( | |||
| @@ -129,11 +119,11 @@ class TestUploadDocuments: | |||
| assert res.json()["code"] == 101 | |||
| assert res.json()["message"] == "No file selected!" | |||
| def test_filename_exceeds_max_length(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_filename_exceeds_max_length(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| # filename_length = 129 | |||
| fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt") | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 101 | |||
| assert res["message"] == "File name should be less than 128 bytes." | |||
| @@ -143,61 +133,61 @@ class TestUploadDocuments: | |||
| assert res["code"] == 100 | |||
| assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")""" | |||
| def test_duplicate_files(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_duplicate_files(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fp = create_txt_file(tmp_path / "ragflow_test.txt") | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp, fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp, fp]) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 2 | |||
| for i in range(len(res["data"])): | |||
| assert res["data"][i]["dataset_id"] == ids[0] | |||
| assert res["data"][i]["dataset_id"] == dataset_id | |||
| expected_name = fp.name | |||
| if i != 0: | |||
| expected_name = f"{fp.stem}({i}){fp.suffix}" | |||
| assert res["data"][i]["name"] == expected_name | |||
| def test_same_file_repeat(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_same_file_repeat(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| fp = create_txt_file(tmp_path / "ragflow_test.txt") | |||
| for i in range(10): | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 1 | |||
| assert res["data"][0]["dataset_id"] == ids[0] | |||
| assert res["data"][0]["dataset_id"] == dataset_id | |||
| expected_name = fp.name | |||
| if i != 0: | |||
| expected_name = f"{fp.stem}({i}){fp.suffix}" | |||
| assert res["data"][0]["name"] == expected_name | |||
| def test_filename_special_characters(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_filename_special_characters(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| illegal_chars = '<>:"/\\|?*' | |||
| translation_table = str.maketrans({char: "_" for char in illegal_chars}) | |||
| safe_filename = string.punctuation.translate(translation_table) | |||
| fp = tmp_path / f"{safe_filename}.txt" | |||
| fp.write_text("Sample text content") | |||
| res = upload_documnets(get_http_api_auth, ids[0], [fp]) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, [fp]) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 1 | |||
| assert res["data"][0]["dataset_id"] == ids[0] | |||
| assert res["data"][0]["dataset_id"] == dataset_id | |||
| assert res["data"][0]["name"] == fp.name | |||
| def test_multiple_files(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_multiple_files(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| expected_document_count = 20 | |||
| fps = [] | |||
| for i in range(expected_document_count): | |||
| fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt") | |||
| fps.append(fp) | |||
| res = upload_documnets(get_http_api_auth, ids[0], fps) | |||
| res = upload_documnets(get_http_api_auth, dataset_id, fps) | |||
| assert res["code"] == 0 | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["document_count"] == expected_document_count | |||
| def test_concurrent_upload(self, get_http_api_auth, tmp_path): | |||
| ids = batch_create_datasets(get_http_api_auth, 1) | |||
| def test_concurrent_upload(self, get_http_api_auth, add_dataset_func, tmp_path): | |||
| dataset_id = add_dataset_func | |||
| expected_document_count = 20 | |||
| fps = [] | |||
| @@ -206,9 +196,9 @@ class TestUploadDocuments: | |||
| fps.append(fp) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(upload_documnets, get_http_api_auth, ids[0], fps[i : i + 1]) for i in range(expected_document_count)] | |||
| futures = [executor.submit(upload_documnets, get_http_api_auth, dataset_id, fps[i : i + 1]) for i in range(expected_document_count)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| res = list_dataset(get_http_api_auth, {"id": ids[0]}) | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| assert res["data"][0]["document_count"] == expected_document_count | |||