### What problem does this PR solve? Test cases about dataset ### Type of change - [x] Other (please describe): test cases --------- Signed-off-by: jinhai <haijin.chn@gmail.com>tags/v0.14.1
| @@ -29,6 +29,7 @@ from api.db.db_models import File | |||
| from api.utils.api_utils import get_json_result | |||
| from api import settings | |||
| from rag.nlp import search | |||
| from api.constants import DATASET_NAME_LIMIT | |||
| @manager.route('/create', methods=['post']) | |||
| @@ -36,10 +37,19 @@ from rag.nlp import search | |||
| @validate_request("name") | |||
| def create(): | |||
| req = request.json | |||
| req["name"] = req["name"].strip() | |||
| req["name"] = duplicate_name( | |||
| dataset_name = req["name"] | |||
| if not isinstance(dataset_name, str): | |||
| return get_data_error_result(message="Dataset name must be string.") | |||
| if dataset_name == "": | |||
| return get_data_error_result(message="Dataset name can't be empty.") | |||
| if len(dataset_name) >= DATASET_NAME_LIMIT: | |||
| return get_data_error_result( | |||
| message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}") | |||
| dataset_name = dataset_name.strip() | |||
| dataset_name = duplicate_name( | |||
| KnowledgebaseService.query, | |||
| name=req["name"], | |||
| name=dataset_name, | |||
| tenant_id=current_user.id, | |||
| status=StatusEnum.VALID.value) | |||
| try: | |||
| @@ -73,7 +83,8 @@ def update(): | |||
| if not KnowledgebaseService.query( | |||
| created_by=current_user.id, id=req["kb_id"]): | |||
| return get_json_result( | |||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) | |||
| data=False, message='Only owner of knowledgebase authorized for this operation.', | |||
| code=settings.RetCode.OPERATING_ERROR) | |||
| e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) | |||
| if not e: | |||
| @@ -81,7 +92,8 @@ def update(): | |||
| message="Can't find this knowledgebase!") | |||
| if req["name"].lower() != kb.name.lower() \ | |||
| and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: | |||
| and len( | |||
| KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: | |||
| return get_data_error_result( | |||
| message="Duplicated knowledgebase name.") | |||
| @@ -152,10 +164,11 @@ def rm(): | |||
| ) | |||
| try: | |||
| kbs = KnowledgebaseService.query( | |||
| created_by=current_user.id, id=req["kb_id"]) | |||
| created_by=current_user.id, id=req["kb_id"]) | |||
| if not kbs: | |||
| return get_json_result( | |||
| data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) | |||
| data=False, message='Only owner of knowledgebase authorized for this operation.', | |||
| code=settings.RetCode.OPERATING_ERROR) | |||
| for doc in DocumentService.query(kb_id=req["kb_id"]): | |||
| if not DocumentService.remove_document(doc, kbs[0].tenant_id): | |||
| @@ -23,3 +23,5 @@ API_VERSION = "v1" | |||
| RAG_FLOW_SERVICE_NAME = "ragflow" | |||
| REQUEST_WAIT_SEC = 2 | |||
| REQUEST_MAX_WAIT_SEC = 300 | |||
| DATASET_NAME_LIMIT = 128 | |||
| @@ -310,7 +310,9 @@ class InfinityConnection(DocStoreConnection): | |||
| table_name = f"{indexName}_{knowledgebaseId}" | |||
| table_instance = db_instance.get_table(table_name) | |||
| kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl() | |||
| df_list.append(kb_res) | |||
| if len(kb_res) != 0 and kb_res.shape[0] > 0: | |||
| df_list.append(kb_res) | |||
| self.connPool.release_conn(inf_conn) | |||
| res = concat_dataframes(df_list, ["id"]) | |||
| res_fields = self.getFields(res, res.columns) | |||
| @@ -3,6 +3,8 @@ import requests | |||
| HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380') | |||
| DATASET_NAME_LIMIT = 128 | |||
| def create_dataset(auth, dataset_name): | |||
| authorization = {"Authorization": auth} | |||
| url = f"{HOST_ADDRESS}/v1/kb/create" | |||
| @@ -24,3 +26,9 @@ def rm_dataset(auth, dataset_id): | |||
| json = {"kb_id": dataset_id} | |||
| res = requests.post(url=url, headers=authorization, json=json) | |||
| return res.json() | |||
| def update_dataset(auth, json_req): | |||
| authorization = {"Authorization": auth} | |||
| url = f"{HOST_ADDRESS}/v1/kb/update" | |||
| res = requests.post(url=url, headers=authorization, json=json_req) | |||
| return res.json() | |||
| @@ -1,6 +1,8 @@ | |||
| from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset | |||
| import requests | |||
| from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT | |||
| import re | |||
| import pytest | |||
| import random | |||
| import string | |||
| def test_dataset(get_auth): | |||
| # create dataset | |||
| @@ -56,8 +58,76 @@ def test_dataset_1k_dataset(get_auth): | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| print(f"{len(dataset_list)} datasets are deleted") | |||
| # delete dataset | |||
| # create invalid name dataset | |||
| def test_duplicated_name_dataset(get_auth): | |||
| # create dataset | |||
| for i in range(20): | |||
| res = create_dataset(get_auth, "test_create_dataset") | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| # list dataset | |||
| res = list_dataset(get_auth, 1) | |||
| data = res.get("data") | |||
| dataset_list = [] | |||
| pattern = r'^test_create_dataset.*' | |||
| for item in data: | |||
| dataset_name = item.get("name") | |||
| dataset_id = item.get("id") | |||
| dataset_list.append(dataset_id) | |||
| match = re.match(pattern, dataset_name) | |||
| assert match != None | |||
| for dataset_id in dataset_list: | |||
| res = rm_dataset(get_auth, dataset_id) | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| print(f"{len(dataset_list)} datasets are deleted") | |||
| def test_invalid_name_dataset(get_auth): | |||
| # create dataset | |||
| # with pytest.raises(Exception) as e: | |||
| res = create_dataset(get_auth, 0) | |||
| assert res['code'] == 102 | |||
| res = create_dataset(get_auth, "") | |||
| assert res['code'] == 102 | |||
| long_string = "" | |||
| while len(long_string) <= DATASET_NAME_LIMIT: | |||
| long_string += random.choice(string.ascii_letters + string.digits) | |||
| res = create_dataset(get_auth, long_string) | |||
| assert res['code'] == 102 | |||
| print(res) | |||
| def test_update_different_params_dataset(get_auth): | |||
| # create dataset | |||
| res = create_dataset(get_auth, "test_create_dataset") | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| # list dataset | |||
| page_number = 1 | |||
| dataset_list = [] | |||
| while True: | |||
| res = list_dataset(get_auth, page_number) | |||
| data = res.get("data") | |||
| for item in data: | |||
| dataset_id = item.get("id") | |||
| dataset_list.append(dataset_id) | |||
| if len(dataset_list) < page_number * 150: | |||
| break | |||
| page_number += 1 | |||
| print(f"found {len(dataset_list)} datasets") | |||
| dataset_id = dataset_list[0] | |||
| json_req = {"kb_id": dataset_id, "name": "test_update_dataset", "description": "test", "permission": "me", "parser_id": "presentation"} | |||
| res = update_dataset(get_auth, json_req) | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| # delete dataset | |||
| for dataset_id in dataset_list: | |||
| res = rm_dataset(get_auth, dataset_id) | |||
| assert res.get("code") == 0, f"{res.get('message')}" | |||
| print(f"{len(dataset_list)} datasets are deleted") | |||
| # update dataset with different parameters | |||
| # create duplicated name dataset | |||
| # | |||
| @@ -15,7 +15,7 @@ get_distro_info() { | |||
| echo "$distro_id $distro_version (Kernel version: $kernel_version)" | |||
| } | |||
| # get Git repo name | |||
| # get Git repository name | |||
| git_repo_name='' | |||
| if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then | |||
| git_repo_name=$(basename "$(git rev-parse --show-toplevel)") | |||
| @@ -48,8 +48,8 @@ else | |||
| python_version="Python not installed" | |||
| fi | |||
| # Print all infomation | |||
| echo "Current Repo: $git_repo_name" | |||
| # Print all information | |||
| echo "Current Repository: $git_repo_name" | |||
| # get Commit ID | |||
| git_version=$(git log -1 --pretty=format:'%h') | |||