### What problem does this PR solve? - Implement new SDK API test cases for chat assistant CRUD operations - Enhance HTTP API concurrent tests to use as_completed for better reliability ### Type of change - [x] Add test cases - [x] Refactoringtags/v0.19.1
| @@ -13,7 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, batch_create_chat_assistants, delete_chat_assistants, list_chat_assistants | |||
| @@ -107,16 +107,18 @@ class TestChatAssistantsDelete: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_deletion(self, api_key): | |||
| ids = batch_create_chat_assistants(api_key, 100) | |||
| count = 100 | |||
| ids = batch_create_chat_assistants(api_key, count) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(delete_chat_assistants, api_key, {"ids": ids[i : i + 1]}) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| futures = [executor.submit(delete_chat_assistants, api_key, {"ids": ids[i : i + 1]}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_delete_10k(self, api_key): | |||
| ids = batch_create_chat_assistants(api_key, 10_000) | |||
| ids = batch_create_chat_assistants(api_key, 1_000) | |||
| res = delete_chat_assistants(api_key, {"ids": ids}) | |||
| assert res["code"] == 0 | |||
| @@ -13,7 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, delete_datasets, list_chat_assistants | |||
| @@ -288,10 +288,12 @@ class TestChatAssistantsList: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, api_key): | |||
| count = 100 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_chat_assistants, api_key) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| futures = [executor.submit(list_chat_assistants, api_key) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_invalid_params(self, api_key): | |||
| @@ -40,3 +40,8 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do | |||
| # CHUNK MANAGEMENT WITHIN DATASET | |||
| def batch_add_chunks(document: Document, num: int): | |||
| return [document.add_chunk(content=f"chunk test {i}") for i in range(num)] | |||
| # CHAT ASSISTANT MANAGEMENT | |||
| def batch_create_chat_assistants(client: RAGFlow, num: int): | |||
| return [client.create_chat(name=f"test_chat_assistant_{i}") for i in range(num)] | |||
| @@ -90,6 +90,14 @@ def clear_datasets(request: FixtureRequest, client: RAGFlow): | |||
| request.addfinalizer(cleanup) | |||
| @pytest.fixture(scope="function") | |||
| def clear_chat_assistants(request: FixtureRequest, client: RAGFlow): | |||
| def cleanup(): | |||
| client.delete_chats(ids=None) | |||
| request.addfinalizer(cleanup) | |||
| @pytest.fixture(scope="class") | |||
| def add_dataset(request: FixtureRequest, client: RAGFlow): | |||
| def cleanup(): | |||
| @@ -137,3 +145,22 @@ def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document]) | |||
| request.addfinalizer(cleanup) | |||
| return dataset, document, chunks | |||
| @pytest.fixture(scope="class") | |||
| def add_chat_assistants(request, client, add_document): | |||
| def cleanup(): | |||
| client.delete_chats(ids=None) | |||
| request.addfinalizer(cleanup) | |||
| dataset, document = add_document | |||
| dataset.async_parse_documents([document.id]) | |||
| condition(dataset) | |||
| chat_assistants = [] | |||
| for i in range(5): | |||
| chat_assistant = client.create_chat(name=f"test_chat_assistant_{i}", dataset_ids=[dataset.id]) | |||
| chat_assistants.append(chat_assistant) | |||
| return dataset, document, chat_assistants | |||
| @@ -0,0 +1,47 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import pytest | |||
| from pytest import FixtureRequest | |||
| from ragflow_sdk import Chat, DataSet, Document, RAGFlow | |||
| from utils import wait_for | |||
| @wait_for(30, 1, "Document parsing timeout") | |||
| def condition(_dataset: DataSet): | |||
| documents = _dataset.list_documents(page_size=1000) | |||
| for document in documents: | |||
| if document.run != "DONE": | |||
| return False | |||
| return True | |||
| @pytest.fixture(scope="function") | |||
| def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chat]]: | |||
| def cleanup(): | |||
| client.delete_chats(ids=None) | |||
| request.addfinalizer(cleanup) | |||
| dataset, document = add_document | |||
| dataset.async_parse_documents([document.id]) | |||
| condition(dataset) | |||
| chat_assistants = [] | |||
| for i in range(5): | |||
| chat_assistant = client.create_chat(name=f"test_chat_assistant_{i}", dataset_ids=[dataset.id]) | |||
| chat_assistants.append(chat_assistant) | |||
| return dataset, document, chat_assistants | |||
| @@ -0,0 +1,224 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from operator import attrgetter | |||
| import pytest | |||
| from configs import CHAT_ASSISTANT_NAME_LIMIT | |||
| from ragflow_sdk import Chat | |||
| from utils import encode_avatar | |||
| from utils.file_utils import create_image_file | |||
| @pytest.mark.usefixtures("clear_chat_assistants") | |||
| class TestChatAssistantCreate: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.usefixtures("add_chunks") | |||
| @pytest.mark.parametrize( | |||
| "name, expected_message", | |||
| [ | |||
| ("valid_name", ""), | |||
| pytest.param("a" * (CHAT_ASSISTANT_NAME_LIMIT + 1), "", marks=pytest.mark.skip(reason="issues/")), | |||
| pytest.param(1, "", marks=pytest.mark.skip(reason="issues/")), | |||
| ("", "`name` is required."), | |||
| ("duplicated_name", "Duplicated chat name in creating chat."), | |||
| ("case insensitive", "Duplicated chat name in creating chat."), | |||
| ], | |||
| ) | |||
| def test_name(self, client, name, expected_message): | |||
| if name == "duplicated_name": | |||
| client.create_chat(name=name) | |||
| elif name == "case insensitive": | |||
| client.create_chat(name=name.upper()) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_chat(name=name) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant = client.create_chat(name=name) | |||
| assert chat_assistant.name == name | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "dataset_ids, expected_message", | |||
| [ | |||
| ([], ""), | |||
| (lambda r: [r], ""), | |||
| (["invalid_dataset_id"], "You don't own the dataset invalid_dataset_id"), | |||
| ("invalid_dataset_id", "You don't own the dataset i"), | |||
| ], | |||
| ) | |||
| def test_dataset_ids(self, client, add_chunks, dataset_ids, expected_message): | |||
| dataset, _, _ = add_chunks | |||
| if callable(dataset_ids): | |||
| dataset_ids = dataset_ids(dataset.id) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_chat(name="ragflow test", dataset_ids=dataset_ids) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant = client.create_chat(name="ragflow test", dataset_ids=dataset_ids) | |||
| assert chat_assistant.name == "ragflow test" | |||
| @pytest.mark.p3 | |||
| def test_avatar(self, client, tmp_path): | |||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||
| chat_assistant = client.create_chat(name="avatar_test", avatar=encode_avatar(fn), dataset_ids=[]) | |||
| assert chat_assistant.name == "avatar_test" | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "llm, expected_message", | |||
| [ | |||
| ({}, ""), | |||
| ({"model_name": "glm-4"}, ""), | |||
| ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), | |||
| ({"temperature": 0}, ""), | |||
| ({"temperature": 1}, ""), | |||
| pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"temperature": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip), | |||
| ({"top_p": 0}, ""), | |||
| ({"top_p": 1}, ""), | |||
| pytest.param({"top_p": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_p": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip), | |||
| ({"presence_penalty": 0}, ""), | |||
| ({"presence_penalty": 1}, ""), | |||
| pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip), | |||
| ({"frequency_penalty": 0}, ""), | |||
| ({"frequency_penalty": 1}, ""), | |||
| pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"frequency_penalty": "a"}, "", marks=pytest.mark.skip), | |||
| ({"max_token": 0}, ""), | |||
| ({"max_token": 1024}, ""), | |||
| pytest.param({"max_token": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"max_token": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip), | |||
| pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_llm(self, client, add_chunks, llm, expected_message): | |||
| dataset, _, _ = add_chunks | |||
| llm_o = Chat.LLM(client, llm) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o) | |||
| if llm: | |||
| for k, v in llm.items(): | |||
| assert attrgetter(k)(chat_assistant.llm) == v | |||
| else: | |||
| assert attrgetter("model_name")(chat_assistant.llm) == "glm-4-flash@ZHIPU-AI" | |||
| assert attrgetter("temperature")(chat_assistant.llm) == 0.1 | |||
| assert attrgetter("top_p")(chat_assistant.llm) == 0.3 | |||
| assert attrgetter("presence_penalty")(chat_assistant.llm) == 0.4 | |||
| assert attrgetter("frequency_penalty")(chat_assistant.llm) == 0.7 | |||
| assert attrgetter("max_tokens")(chat_assistant.llm) == 512 | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "prompt, expected_message", | |||
| [ | |||
| ({"similarity_threshold": 0}, ""), | |||
| ({"similarity_threshold": 1}, ""), | |||
| pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), | |||
| ({"keywords_similarity_weight": 0}, ""), | |||
| ({"keywords_similarity_weight": 1}, ""), | |||
| pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), | |||
| ({"variables": []}, ""), | |||
| ({"top_n": 0}, ""), | |||
| ({"top_n": 1}, ""), | |||
| pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), | |||
| ({"empty_response": "Hello World"}, ""), | |||
| ({"empty_response": ""}, ""), | |||
| ({"empty_response": "!@#$%^&*()"}, ""), | |||
| ({"empty_response": "中文测试"}, ""), | |||
| pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), | |||
| pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), | |||
| pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), | |||
| ({"opener": "Hello World"}, ""), | |||
| ({"opener": ""}, ""), | |||
| ({"opener": "!@#$%^&*()"}, ""), | |||
| ({"opener": "中文测试"}, ""), | |||
| pytest.param({"opener": 123}, "", marks=pytest.mark.skip), | |||
| pytest.param({"opener": True}, "", marks=pytest.mark.skip), | |||
| pytest.param({"opener": " "}, "", marks=pytest.mark.skip), | |||
| ({"show_quote": True}, ""), | |||
| ({"show_quote": False}, ""), | |||
| ({"prompt": "Hello World {knowledge}"}, ""), | |||
| ({"prompt": "{knowledge}"}, ""), | |||
| ({"prompt": "!@#$%^&*() {knowledge}"}, ""), | |||
| ({"prompt": "中文测试 {knowledge}"}, ""), | |||
| ({"prompt": "Hello World"}, ""), | |||
| ({"prompt": "Hello World", "variables": []}, ""), | |||
| pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), | |||
| pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), | |||
| pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_prompt(self, client, add_chunks, prompt, expected_message): | |||
| dataset, _, _ = add_chunks | |||
| prompt_o = Chat.Prompt(client, prompt) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant = client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o) | |||
| if prompt: | |||
| for k, v in prompt.items(): | |||
| if k == "keywords_similarity_weight": | |||
| assert attrgetter(k)(chat_assistant.prompt) == 1 - v | |||
| else: | |||
| assert attrgetter(k)(chat_assistant.prompt) == v | |||
| else: | |||
| assert attrgetter("similarity_threshold")(chat_assistant.prompt) == 0.2 | |||
| assert attrgetter("keywords_similarity_weight")(chat_assistant.prompt) == 0.7 | |||
| assert attrgetter("top_n")(chat_assistant.prompt) == 6 | |||
| assert attrgetter("variables")(chat_assistant.prompt) == [{"key": "knowledge", "optional": False}] | |||
| assert attrgetter("rerank_model")(chat_assistant.prompt) == "" | |||
| assert attrgetter("empty_response")(chat_assistant.prompt) == "Sorry! No relevant content was found in the knowledge base!" | |||
| assert attrgetter("opener")(chat_assistant.prompt) == "Hi! I'm your assistant, what can I do for you?" | |||
| assert attrgetter("show_quote")(chat_assistant.prompt) is True | |||
| assert ( | |||
| attrgetter("prompt")(chat_assistant.prompt) | |||
| == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.' | |||
| ) | |||
| class TestChatAssistantCreate2: | |||
| @pytest.mark.p2 | |||
| def test_unparsed_document(self, client, add_document): | |||
| dataset, _ = add_document | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_chat(name="prompt_test", dataset_ids=[dataset.id]) | |||
| assert "doesn't own parsed file" in str(excinfo.value) | |||
| @@ -0,0 +1,105 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import batch_create_chat_assistants | |||
| class TestChatAssistantsDelete: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message, remaining", | |||
| [ | |||
| pytest.param(None, "", 0, marks=pytest.mark.p3), | |||
| pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), | |||
| pytest.param({"ids": ["invalid_id"]}, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3), | |||
| pytest.param({"ids": ["\n!?。;!?\"'"]}, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"ids": r}, "", 0, marks=pytest.mark.p1), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, client, add_chat_assistants_func, payload, expected_message, remaining): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| if callable(payload): | |||
| payload = payload([chat_assistant.id for chat_assistant in chat_assistants]) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.delete_chats(**payload) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| if payload is None: | |||
| client.delete_chats(payload) | |||
| else: | |||
| client.delete_chats(**payload) | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == remaining | |||
| @pytest.mark.parametrize( | |||
| "payload", | |||
| [ | |||
| pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1), | |||
| pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_delete_partial_invalid_id(self, client, add_chat_assistants_func, payload): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| payload = payload([chat_assistant.id for chat_assistant in chat_assistants]) | |||
| client.delete_chats(**payload) | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == 0 | |||
| @pytest.mark.p3 | |||
| def test_repeated_deletion(self, client, add_chat_assistants_func): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_ids = [chat.id for chat in chat_assistants] | |||
| client.delete_chats(ids=chat_ids) | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.delete_chats(ids=chat_ids) | |||
| assert "not found" in str(excinfo.value) | |||
| @pytest.mark.p3 | |||
| def test_duplicate_deletion(self, client, add_chat_assistants_func): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_ids = [chat.id for chat in chat_assistants] | |||
| client.delete_chats(ids=chat_ids + chat_ids) | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == 0 | |||
| @pytest.mark.p3 | |||
| def test_concurrent_deletion(self, client): | |||
| count = 100 | |||
| chat_ids = [client.create_chat(name=f"test_{i}").id for i in range(count)] | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(client.delete_chats, ids=[chat_ids[i]]) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count | |||
| assert all(future.exception() is None for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_delete_1k(self, client): | |||
| chat_assistants = batch_create_chat_assistants(client, 1_000) | |||
| client.delete_chats(ids=[chat_assistants.id for chat_assistants in chat_assistants]) | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == 0 | |||
| @@ -0,0 +1,224 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| @pytest.mark.usefixtures("add_chat_assistants") | |||
| class TestChatAssistantsList: | |||
| @pytest.mark.p1 | |||
| def test_default(self, client): | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == 5 | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size, expected_message", | |||
| [ | |||
| ({"page": 0, "page_size": 2}, 2, ""), | |||
| ({"page": 2, "page_size": 2}, 2, ""), | |||
| ({"page": 3, "page_size": 2}, 1, ""), | |||
| ({"page": "3", "page_size": 2}, 0, "not instance of"), | |||
| pytest.param( | |||
| {"page": -1, "page_size": 2}, | |||
| 0, | |||
| "1064", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| pytest.param( | |||
| {"page": "a", "page_size": 2}, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: \'a\'")""", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ], | |||
| ) | |||
| def test_page(self, client, params, expected_page_size, expected_message): | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| assistants = client.list_chats(**params) | |||
| assert len(assistants) == expected_page_size | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size, expected_message", | |||
| [ | |||
| ({"page_size": 0}, 0, ""), | |||
| ({"page_size": 1}, 1, ""), | |||
| ({"page_size": 6}, 5, ""), | |||
| ({"page_size": "1"}, 0, "not instance of"), | |||
| pytest.param( | |||
| {"page_size": -1}, | |||
| 0, | |||
| "1064", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| pytest.param( | |||
| {"page_size": "a"}, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: \'a\'")""", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ], | |||
| ) | |||
| def test_page_size(self, client, params, expected_page_size, expected_message): | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| assistants = client.list_chats(**params) | |||
| assert len(assistants) == expected_page_size | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_message", | |||
| [ | |||
| ({"orderby": "create_time"}, ""), | |||
| ({"orderby": "update_time"}, ""), | |||
| pytest.param({"orderby": "name", "desc": "False"}, "", marks=pytest.mark.skip(reason="issues/5851")), | |||
| pytest.param({"orderby": "unknown"}, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")), | |||
| ], | |||
| ) | |||
| def test_orderby(self, client, params, expected_message): | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| client.list_chats(**params) | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_message", | |||
| [ | |||
| ({"desc": None}, "not instance of"), | |||
| ({"desc": "true"}, "not instance of"), | |||
| ({"desc": "True"}, "not instance of"), | |||
| ({"desc": True}, ""), | |||
| ({"desc": "false"}, "not instance of"), | |||
| ({"desc": "False"}, "not instance of"), | |||
| ({"desc": False}, ""), | |||
| ({"desc": "False", "orderby": "update_time"}, "not instance of"), | |||
| pytest.param( | |||
| {"desc": "unknown"}, | |||
| "desc should be true or false", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ], | |||
| ) | |||
| def test_desc(self, client, params, expected_message): | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| client.list_chats(**params) | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_num, expected_message", | |||
| [ | |||
| ({"name": None}, 5, ""), | |||
| ({"name": ""}, 5, ""), | |||
| ({"name": "test_chat_assistant_1"}, 1, ""), | |||
| ({"name": "unknown"}, 0, "The chat doesn't exist"), | |||
| ], | |||
| ) | |||
| def test_name(self, client, params, expected_num, expected_message): | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| assistants = client.list_chats(**params) | |||
| if params["name"] in [None, ""]: | |||
| assert len(assistants) == expected_num | |||
| else: | |||
| assert assistants[0].name == params["name"] | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "chat_assistant_id, expected_num, expected_message", | |||
| [ | |||
| (None, 5, ""), | |||
| ("", 5, ""), | |||
| (lambda r: r[0], 1, ""), | |||
| ("unknown", 0, "The chat doesn't exist"), | |||
| ], | |||
| ) | |||
| def test_id(self, client, add_chat_assistants, chat_assistant_id, expected_num, expected_message): | |||
| _, _, chat_assistants = add_chat_assistants | |||
| if callable(chat_assistant_id): | |||
| params = {"id": chat_assistant_id([chat.id for chat in chat_assistants])} | |||
| else: | |||
| params = {"id": chat_assistant_id} | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| assistants = client.list_chats(**params) | |||
| if params["id"] in [None, ""]: | |||
| assert len(assistants) == expected_num | |||
| else: | |||
| assert assistants[0].id == params["id"] | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "chat_assistant_id, name, expected_num, expected_message", | |||
| [ | |||
| (lambda r: r[0], "test_chat_assistant_0", 1, ""), | |||
| (lambda r: r[0], "test_chat_assistant_1", 0, "The chat doesn't exist"), | |||
| (lambda r: r[0], "unknown", 0, "The chat doesn't exist"), | |||
| ("id", "chat_assistant_0", 0, "The chat doesn't exist"), | |||
| ], | |||
| ) | |||
| def test_name_and_id(self, client, add_chat_assistants, chat_assistant_id, name, expected_num, expected_message): | |||
| _, _, chat_assistants = add_chat_assistants | |||
| if callable(chat_assistant_id): | |||
| params = {"id": chat_assistant_id([chat.id for chat in chat_assistants]), "name": name} | |||
| else: | |||
| params = {"id": chat_assistant_id, "name": name} | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.list_chats(**params) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| assistants = client.list_chats(**params) | |||
| assert len(assistants) == expected_num | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, client): | |||
| count = 100 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(client.list_chats) for _ in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| @pytest.mark.p2 | |||
| def test_list_chats_after_deleting_associated_dataset(self, client, add_chat_assistants): | |||
| dataset, _, _ = add_chat_assistants | |||
| client.delete_datasets(ids=[dataset.id]) | |||
| assistants = client.list_chats() | |||
| assert len(assistants) == 5 | |||
| @@ -0,0 +1,208 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from operator import attrgetter | |||
| import pytest | |||
| from configs import CHAT_ASSISTANT_NAME_LIMIT | |||
| from ragflow_sdk import Chat | |||
| from utils import encode_avatar | |||
| from utils.file_utils import create_image_file | |||
| class TestChatAssistantUpdate: | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| pytest.param({"name": "valid_name"}, "", marks=pytest.mark.p1), | |||
| pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, "", marks=pytest.mark.skip(reason="issues/")), | |||
| pytest.param({"name": 1}, "", marks=pytest.mark.skip(reason="issues/")), | |||
| pytest.param({"name": ""}, "`name` cannot be empty.", marks=pytest.mark.p3), | |||
| pytest.param({"name": "test_chat_assistant_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), | |||
| pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_name(self, client, add_chat_assistants_func, payload, expected_message): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_assistant = chat_assistants[0] | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| chat_assistant.update(payload) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant.update(payload) | |||
| updated_chat = client.list_chats(id=chat_assistant.id)[0] | |||
| assert updated_chat.name == payload["name"], str(updated_chat) | |||
| @pytest.mark.p3 | |||
| def test_avatar(self, client, add_chat_assistants_func, tmp_path): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_assistant = chat_assistants[0] | |||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||
| payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": chat_assistant.dataset_ids} | |||
| chat_assistant.update(payload) | |||
| updated_chat = client.list_chats(id=chat_assistant.id)[0] | |||
| assert updated_chat.name == payload["name"], str(updated_chat) | |||
| assert updated_chat.avatar is not None, str(updated_chat) | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "llm, expected_message", | |||
| [ | |||
| ({}, "ValueError"), | |||
| ({"model_name": "glm-4"}, ""), | |||
| ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"), | |||
| ({"temperature": 0}, ""), | |||
| ({"temperature": 1}, ""), | |||
| pytest.param({"temperature": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"temperature": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip), | |||
| ({"top_p": 0}, ""), | |||
| ({"top_p": 1}, ""), | |||
| pytest.param({"top_p": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_p": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip), | |||
| ({"presence_penalty": 0}, ""), | |||
| ({"presence_penalty": 1}, ""), | |||
| pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip), | |||
| ({"frequency_penalty": 0}, ""), | |||
| ({"frequency_penalty": 1}, ""), | |||
| pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"frequency_penalty": "a"}, "", marks=pytest.mark.skip), | |||
| ({"max_token": 0}, ""), | |||
| ({"max_token": 1024}, ""), | |||
| pytest.param({"max_token": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"max_token": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip), | |||
| pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_llm(self, client, add_chat_assistants_func, llm, expected_message): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_assistant = chat_assistants[0] | |||
| payload = {"name": "llm_test", "dataset_ids": chat_assistant.dataset_ids, "llm": llm} | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| chat_assistant.update(payload) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant.update(payload) | |||
| updated_chat = client.list_chats(id=chat_assistant.id)[0] | |||
| if llm: | |||
| for k, v in llm.items(): | |||
| assert attrgetter(k)(updated_chat.llm) == v, str(updated_chat) | |||
| else: | |||
| excepted_value = Chat.LLM( | |||
| client, | |||
| { | |||
| "model_name": "glm-4-flash@ZHIPU-AI", | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| "frequency_penalty": 0.7, | |||
| "max_tokens": 512, | |||
| }, | |||
| ) | |||
| assert str(updated_chat.llm) == str(excepted_value), str(updated_chat) | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "prompt, expected_message", | |||
| [ | |||
| ({}, "ValueError"), | |||
| ({"similarity_threshold": 0}, ""), | |||
| ({"similarity_threshold": 1}, ""), | |||
| pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip), | |||
| ({"keywords_similarity_weight": 0}, ""), | |||
| ({"keywords_similarity_weight": 1}, ""), | |||
| pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip), | |||
| ({"variables": []}, ""), | |||
| ({"top_n": 0}, ""), | |||
| ({"top_n": 1}, ""), | |||
| pytest.param({"top_n": -1}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_n": 10}, "", marks=pytest.mark.skip), | |||
| pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip), | |||
| ({"empty_response": "Hello World"}, ""), | |||
| ({"empty_response": ""}, ""), | |||
| ({"empty_response": "!@#$%^&*()"}, ""), | |||
| ({"empty_response": "中文测试"}, ""), | |||
| pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip), | |||
| pytest.param({"empty_response": True}, "", marks=pytest.mark.skip), | |||
| pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip), | |||
| ({"opener": "Hello World"}, ""), | |||
| ({"opener": ""}, ""), | |||
| ({"opener": "!@#$%^&*()"}, ""), | |||
| ({"opener": "中文测试"}, ""), | |||
| pytest.param({"opener": 123}, "", marks=pytest.mark.skip), | |||
| pytest.param({"opener": True}, "", marks=pytest.mark.skip), | |||
| pytest.param({"opener": " "}, "", marks=pytest.mark.skip), | |||
| ({"show_quote": True}, ""), | |||
| ({"show_quote": False}, ""), | |||
| ({"prompt": "Hello World {knowledge}"}, ""), | |||
| ({"prompt": "{knowledge}"}, ""), | |||
| ({"prompt": "!@#$%^&*() {knowledge}"}, ""), | |||
| ({"prompt": "中文测试 {knowledge}"}, ""), | |||
| ({"prompt": "Hello World"}, ""), | |||
| ({"prompt": "Hello World", "variables": []}, ""), | |||
| pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), | |||
| pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip), | |||
| pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_prompt(self, client, add_chat_assistants_func, prompt, expected_message): | |||
| _, _, chat_assistants = add_chat_assistants_func | |||
| chat_assistant = chat_assistants[0] | |||
| payload = {"name": "prompt_test", "dataset_ids": chat_assistant.dataset_ids, "prompt": prompt} | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| chat_assistant.update(payload) | |||
| assert expected_message in str(excinfo.value) | |||
| else: | |||
| chat_assistant.update(payload) | |||
| updated_chat = client.list_chats(id=chat_assistant.id)[0] | |||
| if prompt: | |||
| for k, v in prompt.items(): | |||
| if k == "keywords_similarity_weight": | |||
| assert attrgetter(k)(updated_chat.prompt) == 1 - v, str(updated_chat) | |||
| else: | |||
| assert attrgetter(k)(updated_chat.prompt) == v, str(updated_chat) | |||
| else: | |||
| excepted_value = Chat.LLM( | |||
| client, | |||
| { | |||
| "similarity_threshold": 0.2, | |||
| "keywords_similarity_weight": 0.7, | |||
| "top_n": 6, | |||
| "variables": [{"key": "knowledge", "optional": False}], | |||
| "rerank_model": "", | |||
| "empty_response": "Sorry! No relevant content was found in the knowledge base!", | |||
| "opener": "Hi! I'm your assistant, what can I do for you?", | |||
| "show_quote": True, | |||
| "prompt": 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.', | |||
| }, | |||
| ) | |||
| assert str(updated_chat.prompt) == str(excepted_value), str(updated_chat) | |||