瀏覽代碼

Test: Add SDK API tests for chat assistant management and improve concurrent tests (#8131)

### What problem does this PR solve?

- Implement new SDK API test cases for chat assistant CRUD operations
- Enhance HTTP API concurrent tests to use as_completed for better
reliability

### Type of change

- [x] Add test cases
- [x] Refactoring
tags/v0.19.1
Liu An 4 月之前
父節點
當前提交
4649accd54
沒有連結到貢獻者的電子郵件帳戶。

+ 8
- 6
test/testcases/test_http_api/test_chat_assistant_management/test_delete_chat_assistants.py 查看文件

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import INVALID_API_TOKEN, batch_create_chat_assistants, delete_chat_assistants, list_chat_assistants
@@ -107,16 +107,18 @@ class TestChatAssistantsDelete:

@pytest.mark.p3
def test_concurrent_deletion(self, api_key):
ids = batch_create_chat_assistants(api_key, 100)
count = 100
ids = batch_create_chat_assistants(api_key, count)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_chat_assistants, api_key, {"ids": ids[i : i + 1]}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
futures = [executor.submit(delete_chat_assistants, api_key, {"ids": ids[i : i + 1]}) for i in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

@pytest.mark.p3
def test_delete_10k(self, api_key):
ids = batch_create_chat_assistants(api_key, 10_000)
ids = batch_create_chat_assistants(api_key, 1_000)
res = delete_chat_assistants(api_key, {"ids": ids})
assert res["code"] == 0


+ 6
- 4
test/testcases/test_http_api/test_chat_assistant_management/test_list_chat_assistants.py 查看文件

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import INVALID_API_TOKEN, delete_datasets, list_chat_assistants
@@ -288,10 +288,12 @@ class TestChatAssistantsList:

@pytest.mark.p3
def test_concurrent_list(self, api_key):
count = 100
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_chat_assistants, api_key) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
futures = [executor.submit(list_chat_assistants, api_key) for i in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

@pytest.mark.p3
def test_invalid_params(self, api_key):

+ 5
- 0
test/testcases/test_sdk_api/common.py 查看文件

@@ -40,3 +40,8 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do
# CHUNK MANAGEMENT WITHIN DATASET
def batch_add_chunks(document: Document, num: int):
return [document.add_chunk(content=f"chunk test {i}") for i in range(num)]


# CHAT ASSISTANT MANAGEMENT
def batch_create_chat_assistants(client: RAGFlow, num: int):
    """Create *num* chat assistants named ``test_chat_assistant_<i>`` and return them."""
    assistants = []
    for idx in range(num):
        assistants.append(client.create_chat(name=f"test_chat_assistant_{idx}"))
    return assistants

+ 27
- 0
test/testcases/test_sdk_api/conftest.py 查看文件

@@ -90,6 +90,14 @@ def clear_datasets(request: FixtureRequest, client: RAGFlow):
request.addfinalizer(cleanup)


@pytest.fixture(scope="function")
def clear_chat_assistants(request: FixtureRequest, client: RAGFlow):
def cleanup():
client.delete_chats(ids=None)

request.addfinalizer(cleanup)


@pytest.fixture(scope="class")
def add_dataset(request: FixtureRequest, client: RAGFlow):
def cleanup():
@@ -137,3 +145,22 @@ def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document])

request.addfinalizer(cleanup)
return dataset, document, chunks


@pytest.fixture(scope="class")
def add_chat_assistants(request, client, add_document):
def cleanup():
client.delete_chats(ids=None)

request.addfinalizer(cleanup)

dataset, document = add_document
dataset.async_parse_documents([document.id])
condition(dataset)

chat_assistants = []
for i in range(5):
chat_assistant = client.create_chat(name=f"test_chat_assistant_{i}", dataset_ids=[dataset.id])
chat_assistants.append(chat_assistant)

return dataset, document, chat_assistants

+ 47
- 0
test/testcases/test_sdk_api/test_chat_assistant_management/conftest.py 查看文件

@@ -0,0 +1,47 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from pytest import FixtureRequest
from ragflow_sdk import Chat, DataSet, Document, RAGFlow
from utils import wait_for


@wait_for(30, 1, "Document parsing timeout")
def condition(_dataset: DataSet):
    """Return True once every document in *_dataset* reports run == "DONE"."""
    documents = _dataset.list_documents(page_size=1000)
    return all(document.run == "DONE" for document in documents)


@pytest.fixture(scope="function")
def add_chat_assistants_func(request: FixtureRequest, client: RAGFlow, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chat]]:
def cleanup():
client.delete_chats(ids=None)

request.addfinalizer(cleanup)

dataset, document = add_document
dataset.async_parse_documents([document.id])
condition(dataset)

chat_assistants = []
for i in range(5):
chat_assistant = client.create_chat(name=f"test_chat_assistant_{i}", dataset_ids=[dataset.id])
chat_assistants.append(chat_assistant)

return dataset, document, chat_assistants

+ 224
- 0
test/testcases/test_sdk_api/test_chat_assistant_management/test_create_chat_assistant.py 查看文件

@@ -0,0 +1,224 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from operator import attrgetter

import pytest
from configs import CHAT_ASSISTANT_NAME_LIMIT
from ragflow_sdk import Chat
from utils import encode_avatar
from utils.file_utils import create_image_file


@pytest.mark.usefixtures("clear_chat_assistants")
class TestChatAssistantCreate:
    """SDK tests for RAGFlow.create_chat: validation of name, dataset_ids,
    avatar, llm config, and prompt config.

    An expected_message of "" means the call must succeed; otherwise the call
    must raise and the message must appear in the exception text.
    """

    @pytest.mark.p1
    @pytest.mark.usefixtures("add_chunks")
    @pytest.mark.parametrize(
        "name, expected_message",
        [
            ("valid_name", ""),
            pytest.param("a" * (CHAT_ASSISTANT_NAME_LIMIT + 1), "", marks=pytest.mark.skip(reason="issues/")),
            pytest.param(1, "", marks=pytest.mark.skip(reason="issues/")),
            ("", "`name` is required."),
            ("duplicated_name", "Duplicated chat name in creating chat."),
            ("case insensitive", "Duplicated chat name in creating chat."),
        ],
    )
    def test_name(self, client, name, expected_message):
        """Name validation, including case-insensitive duplicate detection."""
        # Seed a conflicting assistant for the duplicate-name cases.
        if name == "duplicated_name":
            client.create_chat(name=name)
        elif name == "case insensitive":
            # Duplicate check is case-insensitive: seed the upper-cased name.
            client.create_chat(name=name.upper())

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.create_chat(name=name)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant = client.create_chat(name=name)
            assert chat_assistant.name == name

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "dataset_ids, expected_message",
        [
            ([], ""),
            # Lambda defers to the real dataset id supplied by the fixture.
            (lambda r: [r], ""),
            (["invalid_dataset_id"], "You don't own the dataset invalid_dataset_id"),
            ("invalid_dataset_id", "You don't own the dataset i"),
        ],
    )
    def test_dataset_ids(self, client, add_chunks, dataset_ids, expected_message):
        """dataset_ids validation: empty, valid, unknown, and wrong-typed values."""
        dataset, _, _ = add_chunks
        if callable(dataset_ids):
            dataset_ids = dataset_ids(dataset.id)

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.create_chat(name="ragflow test", dataset_ids=dataset_ids)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant = client.create_chat(name="ragflow test", dataset_ids=dataset_ids)
            assert chat_assistant.name == "ragflow test"

    @pytest.mark.p3
    def test_avatar(self, client, tmp_path):
        """An assistant can be created with a base64-encoded avatar image."""
        fn = create_image_file(tmp_path / "ragflow_test.png")
        chat_assistant = client.create_chat(name="avatar_test", avatar=encode_avatar(fn), dataset_ids=[])
        assert chat_assistant.name == "avatar_test"

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "llm, expected_message",
        [
            ({}, ""),
            ({"model_name": "glm-4"}, ""),
            ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"),
            ({"temperature": 0}, ""),
            ({"temperature": 1}, ""),
            pytest.param({"temperature": -1}, "", marks=pytest.mark.skip),
            pytest.param({"temperature": 10}, "", marks=pytest.mark.skip),
            pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip),
            ({"top_p": 0}, ""),
            ({"top_p": 1}, ""),
            pytest.param({"top_p": -1}, "", marks=pytest.mark.skip),
            pytest.param({"top_p": 10}, "", marks=pytest.mark.skip),
            pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip),
            ({"presence_penalty": 0}, ""),
            ({"presence_penalty": 1}, ""),
            pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip),
            ({"frequency_penalty": 0}, ""),
            ({"frequency_penalty": 1}, ""),
            pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": "a"}, "", marks=pytest.mark.skip),
            ({"max_token": 0}, ""),
            ({"max_token": 1024}, ""),
            pytest.param({"max_token": -1}, "", marks=pytest.mark.skip),
            pytest.param({"max_token": 10}, "", marks=pytest.mark.skip),
            pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip),
        ],
    )
    def test_llm(self, client, add_chunks, llm, expected_message):
        """LLM config: overrides round-trip; empty config falls back to server defaults."""
        dataset, _, _ = add_chunks
        llm_o = Chat.LLM(client, llm)

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant = client.create_chat(name="llm_test", dataset_ids=[dataset.id], llm=llm_o)
            if llm:
                # Each override must be echoed back on the created assistant.
                for k, v in llm.items():
                    assert attrgetter(k)(chat_assistant.llm) == v
            else:
                # Server-side defaults expected when no overrides are given.
                assert attrgetter("model_name")(chat_assistant.llm) == "glm-4-flash@ZHIPU-AI"
                assert attrgetter("temperature")(chat_assistant.llm) == 0.1
                assert attrgetter("top_p")(chat_assistant.llm) == 0.3
                assert attrgetter("presence_penalty")(chat_assistant.llm) == 0.4
                assert attrgetter("frequency_penalty")(chat_assistant.llm) == 0.7
                assert attrgetter("max_tokens")(chat_assistant.llm) == 512

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "prompt, expected_message",
        [
            ({"similarity_threshold": 0}, ""),
            ({"similarity_threshold": 1}, ""),
            pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip),
            ({"keywords_similarity_weight": 0}, ""),
            ({"keywords_similarity_weight": 1}, ""),
            pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip),
            ({"variables": []}, ""),
            ({"top_n": 0}, ""),
            ({"top_n": 1}, ""),
            pytest.param({"top_n": -1}, "", marks=pytest.mark.skip),
            pytest.param({"top_n": 10}, "", marks=pytest.mark.skip),
            pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip),
            ({"empty_response": "Hello World"}, ""),
            ({"empty_response": ""}, ""),
            ({"empty_response": "!@#$%^&*()"}, ""),
            ({"empty_response": "中文测试"}, ""),
            pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": True}, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip),
            ({"opener": "Hello World"}, ""),
            ({"opener": ""}, ""),
            ({"opener": "!@#$%^&*()"}, ""),
            ({"opener": "中文测试"}, ""),
            pytest.param({"opener": 123}, "", marks=pytest.mark.skip),
            pytest.param({"opener": True}, "", marks=pytest.mark.skip),
            pytest.param({"opener": " "}, "", marks=pytest.mark.skip),
            ({"show_quote": True}, ""),
            ({"show_quote": False}, ""),
            ({"prompt": "Hello World {knowledge}"}, ""),
            ({"prompt": "{knowledge}"}, ""),
            ({"prompt": "!@#$%^&*() {knowledge}"}, ""),
            ({"prompt": "中文测试 {knowledge}"}, ""),
            ({"prompt": "Hello World"}, ""),
            ({"prompt": "Hello World", "variables": []}, ""),
            pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip),
        ],
    )
    def test_prompt(self, client, add_chunks, prompt, expected_message):
        """Prompt config: overrides round-trip; empty config falls back to server defaults."""
        dataset, _, _ = add_chunks
        prompt_o = Chat.Prompt(client, prompt)

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant = client.create_chat(name="prompt_test", dataset_ids=[dataset.id], prompt=prompt_o)
            if prompt:
                for k, v in prompt.items():
                    if k == "keywords_similarity_weight":
                        # The server stores the complementary weight (1 - v).
                        assert attrgetter(k)(chat_assistant.prompt) == 1 - v
                    else:
                        assert attrgetter(k)(chat_assistant.prompt) == v
            else:
                # Server-side defaults expected when no overrides are given.
                assert attrgetter("similarity_threshold")(chat_assistant.prompt) == 0.2
                assert attrgetter("keywords_similarity_weight")(chat_assistant.prompt) == 0.7
                assert attrgetter("top_n")(chat_assistant.prompt) == 6
                assert attrgetter("variables")(chat_assistant.prompt) == [{"key": "knowledge", "optional": False}]
                assert attrgetter("rerank_model")(chat_assistant.prompt) == ""
                assert attrgetter("empty_response")(chat_assistant.prompt) == "Sorry! No relevant content was found in the knowledge base!"
                assert attrgetter("opener")(chat_assistant.prompt) == "Hi! I'm your assistant, what can I do for you?"
                assert attrgetter("show_quote")(chat_assistant.prompt) is True
                assert (
                    attrgetter("prompt")(chat_assistant.prompt)
                    == 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.'
                )


class TestChatAssistantCreate2:
    """Create-chat cases that must NOT run under the clear_chat_assistants fixture."""

    @pytest.mark.p2
    def test_unparsed_document(self, client, add_document):
        """Creating a chat over a dataset with no parsed documents must fail."""
        dataset, _ = add_document
        with pytest.raises(Exception) as excinfo:
            client.create_chat(name="prompt_test", dataset_ids=[dataset.id])
        assert "doesn't own parsed file" in str(excinfo.value)

+ 105
- 0
test/testcases/test_sdk_api/test_chat_assistant_management/test_delete_chat_assistants.py 查看文件

@@ -0,0 +1,105 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import batch_create_chat_assistants


class TestChatAssistantsDelete:
    """SDK tests for RAGFlow.delete_chats: valid/invalid ids, repeats, concurrency."""

    @pytest.mark.parametrize(
        "payload, expected_message, remaining",
        [
            # payload=None exercises delete-all; `remaining` is the expected
            # number of assistants left after the call (fixture creates 5).
            pytest.param(None, "", 0, marks=pytest.mark.p3),
            pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3),
            pytest.param({"ids": ["invalid_id"]}, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3),
            pytest.param({"ids": ["\n!?。;!?\"'"]}, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3),
            pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3),
            pytest.param(lambda r: {"ids": r}, "", 0, marks=pytest.mark.p1),
        ],
    )
    def test_basic_scenarios(self, client, add_chat_assistants_func, payload, expected_message, remaining):
        """Delete by explicit ids, empty list, None (all), or unknown ids."""
        _, _, chat_assistants = add_chat_assistants_func
        if callable(payload):
            payload = payload([chat_assistant.id for chat_assistant in chat_assistants])

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.delete_chats(**payload)
            assert expected_message in str(excinfo.value)
        else:
            if payload is None:
                # Pass None positionally as `ids` to trigger delete-all.
                client.delete_chats(payload)
            else:
                client.delete_chats(**payload)

        assistants = client.list_chats()
        assert len(assistants) == remaining

    @pytest.mark.parametrize(
        "payload",
        [
            pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
            pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1),
            pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
        ],
    )
    def test_delete_partial_invalid_id(self, client, add_chat_assistants_func, payload):
        """Valid ids in a batch are deleted even when an invalid id is mixed in."""
        _, _, chat_assistants = add_chat_assistants_func
        payload = payload([chat_assistant.id for chat_assistant in chat_assistants])
        client.delete_chats(**payload)

        assistants = client.list_chats()
        assert len(assistants) == 0

    @pytest.mark.p3
    def test_repeated_deletion(self, client, add_chat_assistants_func):
        """Deleting already-deleted ids raises a not-found error."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_ids = [chat.id for chat in chat_assistants]
        client.delete_chats(ids=chat_ids)

        with pytest.raises(Exception) as excinfo:
            client.delete_chats(ids=chat_ids)
        assert "not found" in str(excinfo.value)

    @pytest.mark.p3
    def test_duplicate_deletion(self, client, add_chat_assistants_func):
        """Duplicated ids within one request are tolerated."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_ids = [chat.id for chat in chat_assistants]
        client.delete_chats(ids=chat_ids + chat_ids)

        assistants = client.list_chats()
        assert len(assistants) == 0

    @pytest.mark.p3
    def test_concurrent_deletion(self, client):
        """Delete 100 assistants concurrently; no request may raise."""
        count = 100
        chat_ids = [client.create_chat(name=f"test_{i}").id for i in range(count)]

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(client.delete_chats, ids=[chat_ids[i]]) for i in range(count)]
        responses = list(as_completed(futures))

        assert len(responses) == count
        assert all(future.exception() is None for future in futures)

    @pytest.mark.p3
    def test_delete_1k(self, client):
        """Bulk-delete 1000 assistants in a single request."""
        chat_assistants = batch_create_chat_assistants(client, 1_000)
        # NOTE(review): the comprehension variable shadows the list name
        # `chat_assistants` — works, but confusing; consider renaming.
        client.delete_chats(ids=[chat_assistants.id for chat_assistants in chat_assistants])

        assistants = client.list_chats()
        assert len(assistants) == 0

+ 224
- 0
test/testcases/test_sdk_api/test_chat_assistant_management/test_list_chat_assistants.py 查看文件

@@ -0,0 +1,224 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest


@pytest.mark.usefixtures("add_chat_assistants")
class TestChatAssistantsList:
    """SDK tests for RAGFlow.list_chats: paging, ordering, and name/id filters.

    The class fixture creates exactly five assistants named
    ``test_chat_assistant_0`` .. ``test_chat_assistant_4``.
    """

    @pytest.mark.p1
    def test_default(self, client):
        """Default listing returns all five fixture assistants."""
        assistants = client.list_chats()
        assert len(assistants) == 5

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_page_size, expected_message",
        [
            ({"page": 0, "page_size": 2}, 2, ""),
            ({"page": 2, "page_size": 2}, 2, ""),
            # Last page holds the remainder (5 items, page_size 2 -> 1 item).
            ({"page": 3, "page_size": 2}, 1, ""),
            ({"page": "3", "page_size": 2}, 0, "not instance of"),
            pytest.param(
                {"page": -1, "page_size": 2},
                0,
                "1064",
                marks=pytest.mark.skip(reason="issues/5851"),
            ),
            pytest.param(
                {"page": "a", "page_size": 2},
                0,
                """ValueError("invalid literal for int() with base 10: \'a\'")""",
                marks=pytest.mark.skip(reason="issues/5851"),
            ),
        ],
    )
    def test_page(self, client, params, expected_page_size, expected_message):
        """`page` parameter: valid pages, out-of-range, and wrong types."""
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            assistants = client.list_chats(**params)
            assert len(assistants) == expected_page_size

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_page_size, expected_message",
        [
            ({"page_size": 0}, 0, ""),
            ({"page_size": 1}, 1, ""),
            # page_size larger than the total returns everything (5).
            ({"page_size": 6}, 5, ""),
            ({"page_size": "1"}, 0, "not instance of"),
            pytest.param(
                {"page_size": -1},
                0,
                "1064",
                marks=pytest.mark.skip(reason="issues/5851"),
            ),
            pytest.param(
                {"page_size": "a"},
                0,
                """ValueError("invalid literal for int() with base 10: \'a\'")""",
                marks=pytest.mark.skip(reason="issues/5851"),
            ),
        ],
    )
    def test_page_size(self, client, params, expected_page_size, expected_message):
        """`page_size` parameter: bounds and wrong types."""
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            assistants = client.list_chats(**params)
            assert len(assistants) == expected_page_size

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params, expected_message",
        [
            ({"orderby": "create_time"}, ""),
            ({"orderby": "update_time"}, ""),
            pytest.param({"orderby": "name", "desc": "False"}, "", marks=pytest.mark.skip(reason="issues/5851")),
            pytest.param({"orderby": "unknown"}, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")),
        ],
    )
    def test_orderby(self, client, params, expected_message):
        """`orderby` parameter accepts create_time/update_time only."""
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            client.list_chats(**params)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params, expected_message",
        [
            # `desc` must be a real bool; string forms are rejected.
            ({"desc": None}, "not instance of"),
            ({"desc": "true"}, "not instance of"),
            ({"desc": "True"}, "not instance of"),
            ({"desc": True}, ""),
            ({"desc": "false"}, "not instance of"),
            ({"desc": "False"}, "not instance of"),
            ({"desc": False}, ""),
            ({"desc": "False", "orderby": "update_time"}, "not instance of"),
            pytest.param(
                {"desc": "unknown"},
                "desc should be true or false",
                marks=pytest.mark.skip(reason="issues/5851"),
            ),
        ],
    )
    def test_desc(self, client, params, expected_message):
        """`desc` parameter type validation."""
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            client.list_chats(**params)

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_num, expected_message",
        [
            ({"name": None}, 5, ""),
            ({"name": ""}, 5, ""),
            ({"name": "test_chat_assistant_1"}, 1, ""),
            ({"name": "unknown"}, 0, "The chat doesn't exist"),
        ],
    )
    def test_name(self, client, params, expected_num, expected_message):
        """Filtering by name: empty filter lists all; unknown name raises."""
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            assistants = client.list_chats(**params)
            if params["name"] in [None, ""]:
                assert len(assistants) == expected_num
            else:
                assert assistants[0].name == params["name"]

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "chat_assistant_id, expected_num, expected_message",
        [
            (None, 5, ""),
            ("", 5, ""),
            # Lambda defers to a real id produced by the fixture.
            (lambda r: r[0], 1, ""),
            ("unknown", 0, "The chat doesn't exist"),
        ],
    )
    def test_id(self, client, add_chat_assistants, chat_assistant_id, expected_num, expected_message):
        """Filtering by id: empty filter lists all; unknown id raises."""
        _, _, chat_assistants = add_chat_assistants
        if callable(chat_assistant_id):
            params = {"id": chat_assistant_id([chat.id for chat in chat_assistants])}
        else:
            params = {"id": chat_assistant_id}

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            assistants = client.list_chats(**params)
            if params["id"] in [None, ""]:
                assert len(assistants) == expected_num
            else:
                assert assistants[0].id == params["id"]

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "chat_assistant_id, name, expected_num, expected_message",
        [
            # id and name must refer to the same assistant to match.
            (lambda r: r[0], "test_chat_assistant_0", 1, ""),
            (lambda r: r[0], "test_chat_assistant_1", 0, "The chat doesn't exist"),
            (lambda r: r[0], "unknown", 0, "The chat doesn't exist"),
            ("id", "chat_assistant_0", 0, "The chat doesn't exist"),
        ],
    )
    def test_name_and_id(self, client, add_chat_assistants, chat_assistant_id, name, expected_num, expected_message):
        """Combined id+name filter must agree on a single assistant."""
        _, _, chat_assistants = add_chat_assistants
        if callable(chat_assistant_id):
            params = {"id": chat_assistant_id([chat.id for chat in chat_assistants]), "name": name}
        else:
            params = {"id": chat_assistant_id, "name": name}

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.list_chats(**params)
            assert expected_message in str(excinfo.value)
        else:
            assistants = client.list_chats(**params)
            assert len(assistants) == expected_num

    @pytest.mark.p3
    def test_concurrent_list(self, client):
        """Issue 100 list requests concurrently; all must complete."""
        count = 100
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(client.list_chats) for _ in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses

    @pytest.mark.p2
    def test_list_chats_after_deleting_associated_dataset(self, client, add_chat_assistants):
        """Assistants remain listable after their associated dataset is deleted."""
        dataset, _, _ = add_chat_assistants
        client.delete_datasets(ids=[dataset.id])

        assistants = client.list_chats()
        assert len(assistants) == 5

+ 208
- 0
test/testcases/test_sdk_api/test_chat_assistant_management/test_update_chat_assistant.py 查看文件

@@ -0,0 +1,208 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from operator import attrgetter

import pytest
from configs import CHAT_ASSISTANT_NAME_LIMIT
from ragflow_sdk import Chat
from utils import encode_avatar
from utils.file_utils import create_image_file


class TestChatAssistantUpdate:
    """SDK tests for Chat.update: validation of name, avatar, llm, and prompt.

    An expected_message of "" means the update must succeed; otherwise the
    call must raise and the message must appear in the exception text.
    """

    @pytest.mark.parametrize(
        "payload, expected_message",
        [
            pytest.param({"name": "valid_name"}, "", marks=pytest.mark.p1),
            pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, "", marks=pytest.mark.skip(reason="issues/")),
            pytest.param({"name": 1}, "", marks=pytest.mark.skip(reason="issues/")),
            pytest.param({"name": ""}, "`name` cannot be empty.", marks=pytest.mark.p3),
            pytest.param({"name": "test_chat_assistant_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3),
            pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, "Duplicated chat name in updating chat.", marks=pytest.mark.p3),
        ],
    )
    def test_name(self, client, add_chat_assistants_func, payload, expected_message):
        """Rename validation, including case-insensitive duplicate detection."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_assistant = chat_assistants[0]

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                chat_assistant.update(payload)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant.update(payload)
            # Re-fetch to verify the change was persisted server-side.
            updated_chat = client.list_chats(id=chat_assistant.id)[0]
            assert updated_chat.name == payload["name"], str(updated_chat)

    @pytest.mark.p3
    def test_avatar(self, client, add_chat_assistants_func, tmp_path):
        """The avatar can be replaced with a base64-encoded image."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_assistant = chat_assistants[0]

        fn = create_image_file(tmp_path / "ragflow_test.png")
        payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": chat_assistant.dataset_ids}

        chat_assistant.update(payload)
        updated_chat = client.list_chats(id=chat_assistant.id)[0]
        assert updated_chat.name == payload["name"], str(updated_chat)
        assert updated_chat.avatar is not None, str(updated_chat)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "llm, expected_message",
        [
            ({}, "ValueError"),
            ({"model_name": "glm-4"}, ""),
            ({"model_name": "unknown"}, "`model_name` unknown doesn't exist"),
            ({"temperature": 0}, ""),
            ({"temperature": 1}, ""),
            pytest.param({"temperature": -1}, "", marks=pytest.mark.skip),
            pytest.param({"temperature": 10}, "", marks=pytest.mark.skip),
            pytest.param({"temperature": "a"}, "", marks=pytest.mark.skip),
            ({"top_p": 0}, ""),
            ({"top_p": 1}, ""),
            pytest.param({"top_p": -1}, "", marks=pytest.mark.skip),
            pytest.param({"top_p": 10}, "", marks=pytest.mark.skip),
            pytest.param({"top_p": "a"}, "", marks=pytest.mark.skip),
            ({"presence_penalty": 0}, ""),
            ({"presence_penalty": 1}, ""),
            pytest.param({"presence_penalty": -1}, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": 10}, "", marks=pytest.mark.skip),
            pytest.param({"presence_penalty": "a"}, "", marks=pytest.mark.skip),
            ({"frequency_penalty": 0}, ""),
            ({"frequency_penalty": 1}, ""),
            pytest.param({"frequency_penalty": -1}, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": 10}, "", marks=pytest.mark.skip),
            pytest.param({"frequency_penalty": "a"}, "", marks=pytest.mark.skip),
            ({"max_token": 0}, ""),
            ({"max_token": 1024}, ""),
            pytest.param({"max_token": -1}, "", marks=pytest.mark.skip),
            pytest.param({"max_token": 10}, "", marks=pytest.mark.skip),
            pytest.param({"max_token": "a"}, "", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip),
        ],
    )
    def test_llm(self, client, add_chat_assistants_func, llm, expected_message):
        """LLM config updates: overrides round-trip; empty config is rejected."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_assistant = chat_assistants[0]
        payload = {"name": "llm_test", "dataset_ids": chat_assistant.dataset_ids, "llm": llm}

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                chat_assistant.update(payload)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant.update(payload)
            updated_chat = client.list_chats(id=chat_assistant.id)[0]
            if llm:
                for k, v in llm.items():
                    assert attrgetter(k)(updated_chat.llm) == v, str(updated_chat)
            else:
                # Unreachable today ({} raises), kept for symmetry with create tests.
                expected_value = Chat.LLM(
                    client,
                    {
                        "model_name": "glm-4-flash@ZHIPU-AI",
                        "temperature": 0.1,
                        "top_p": 0.3,
                        "presence_penalty": 0.4,
                        "frequency_penalty": 0.7,
                        "max_tokens": 512,
                    },
                )
                assert str(updated_chat.llm) == str(expected_value), str(updated_chat)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "prompt, expected_message",
        [
            ({}, "ValueError"),
            ({"similarity_threshold": 0}, ""),
            ({"similarity_threshold": 1}, ""),
            pytest.param({"similarity_threshold": -1}, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": 10}, "", marks=pytest.mark.skip),
            pytest.param({"similarity_threshold": "a"}, "", marks=pytest.mark.skip),
            ({"keywords_similarity_weight": 0}, ""),
            ({"keywords_similarity_weight": 1}, ""),
            pytest.param({"keywords_similarity_weight": -1}, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": 10}, "", marks=pytest.mark.skip),
            pytest.param({"keywords_similarity_weight": "a"}, "", marks=pytest.mark.skip),
            ({"variables": []}, ""),
            ({"top_n": 0}, ""),
            ({"top_n": 1}, ""),
            pytest.param({"top_n": -1}, "", marks=pytest.mark.skip),
            pytest.param({"top_n": 10}, "", marks=pytest.mark.skip),
            pytest.param({"top_n": "a"}, "", marks=pytest.mark.skip),
            ({"empty_response": "Hello World"}, ""),
            ({"empty_response": ""}, ""),
            ({"empty_response": "!@#$%^&*()"}, ""),
            ({"empty_response": "中文测试"}, ""),
            pytest.param({"empty_response": 123}, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": True}, "", marks=pytest.mark.skip),
            pytest.param({"empty_response": " "}, "", marks=pytest.mark.skip),
            ({"opener": "Hello World"}, ""),
            ({"opener": ""}, ""),
            ({"opener": "!@#$%^&*()"}, ""),
            ({"opener": "中文测试"}, ""),
            pytest.param({"opener": 123}, "", marks=pytest.mark.skip),
            pytest.param({"opener": True}, "", marks=pytest.mark.skip),
            pytest.param({"opener": " "}, "", marks=pytest.mark.skip),
            ({"show_quote": True}, ""),
            ({"show_quote": False}, ""),
            ({"prompt": "Hello World {knowledge}"}, ""),
            ({"prompt": "{knowledge}"}, ""),
            ({"prompt": "!@#$%^&*() {knowledge}"}, ""),
            ({"prompt": "中文测试 {knowledge}"}, ""),
            ({"prompt": "Hello World"}, ""),
            ({"prompt": "Hello World", "variables": []}, ""),
            pytest.param({"prompt": 123}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"prompt": True}, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
            pytest.param({"unknown": "unknown"}, "", marks=pytest.mark.skip),
        ],
    )
    def test_prompt(self, client, add_chat_assistants_func, prompt, expected_message):
        """Prompt config updates: overrides round-trip; empty config is rejected."""
        _, _, chat_assistants = add_chat_assistants_func
        chat_assistant = chat_assistants[0]
        payload = {"name": "prompt_test", "dataset_ids": chat_assistant.dataset_ids, "prompt": prompt}

        if expected_message:
            with pytest.raises(Exception) as excinfo:
                chat_assistant.update(payload)
            assert expected_message in str(excinfo.value)
        else:
            chat_assistant.update(payload)
            updated_chat = client.list_chats(id=chat_assistant.id)[0]
            if prompt:
                for k, v in prompt.items():
                    if k == "keywords_similarity_weight":
                        # The server stores the complementary weight (1 - v).
                        assert attrgetter(k)(updated_chat.prompt) == 1 - v, str(updated_chat)
                    else:
                        assert attrgetter(k)(updated_chat.prompt) == v, str(updated_chat)
            else:
                # BUG FIX: the expected default prompt must be built with
                # Chat.Prompt, not Chat.LLM, or the string comparison below
                # compares the wrong object type.
                expected_value = Chat.Prompt(
                    client,
                    {
                        "similarity_threshold": 0.2,
                        "keywords_similarity_weight": 0.7,
                        "top_n": 6,
                        "variables": [{"key": "knowledge", "optional": False}],
                        "rerank_model": "",
                        "empty_response": "Sorry! No relevant content was found in the knowledge base!",
                        "opener": "Hi! I'm your assistant, what can I do for you?",
                        "show_quote": True,
                        "prompt": 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.',
                    },
                )
                assert str(updated_chat.prompt) == str(expected_value), str(updated_chat)

Loading…
取消
儲存