### What problem does this PR solve? This PR has completed both the HTTP API and the Python SDK for 'list_dataset'. In addition, there are tests for it. ### Type of change - [x] New Feature (non-breaking change which adds functionality) tags/v0.8.0
| # ------------------------------ create a dataset --------------------------------------- | # ------------------------------ create a dataset --------------------------------------- | ||||
| @manager.route('/', methods=['POST']) | @manager.route('/', methods=['POST']) | ||||
| @login_required # use login | |||||
| @login_required # use login | |||||
| @validate_request("name") # check name key | @validate_request("name") # check name key | ||||
| def create_dataset(): | def create_dataset(): | ||||
| # Check if Authorization header is present | # Check if Authorization header is present | ||||
| if not KnowledgebaseService.save(**request_body): | if not KnowledgebaseService.save(**request_body): | ||||
| # failed to create new dataset | # failed to create new dataset | ||||
| return construct_result() | return construct_result() | ||||
| return construct_json_result(data={"dataset_id": request_body["id"]}) | |||||
| return construct_json_result(data={"dataset_name": request_body["name"]}) | |||||
| except Exception as e: | except Exception as e: | ||||
| return construct_error_response(e) | return construct_error_response(e) | ||||
# -----------------------------list datasets-------------------------------------------------------
@manager.route('/', methods=['GET'])
@login_required
def list_datasets():
    """List the datasets (knowledge bases) visible to the logged-in user.

    Query parameters:
        offset (int): index of the first dataset to return; defaults to 0.
        count (int): maximum number of datasets to return; -1 means "all".
        orderby (str): server-side sort field; defaults to "create_time".
        desc (bool): sort descending when true; defaults to true.

    Returns a JSON result whose ``data`` is the list of dataset dicts, or an
    error response if lookup fails (e.g. offset out of range).
    """
    offset = request.args.get("offset", 0)
    count = request.args.get("count", -1)
    orderby = request.args.get("orderby", "create_time")
    # BUG FIX: request.args values are strings, so the original
    # request.args.get("desc", True) returned the *string* "false" for
    # ?desc=false, which is truthy — descending order was always used.
    # Parse the flag explicitly instead.
    desc_arg = request.args.get("desc", "true")
    desc = str(desc_arg).strip().lower() not in ("false", "0", "no")
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        kbs = KnowledgebaseService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id, int(offset), int(count), orderby, desc)
        # NOTE(review): code=RetCode.DATA_ERROR on the success path matches the
        # other routes in this file, but looks wrong — confirm intended code.
        return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
    except Exception as e:
        return construct_error_response(e)
| # ---------------------------------delete a dataset ---------------------------- | |||||
| @manager.route('/<dataset_id>', methods=['DELETE']) | @manager.route('/<dataset_id>', methods=['DELETE']) | ||||
| @login_required | @login_required | ||||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}") | return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}") | ||||
| @manager.route('/', methods=['GET']) | |||||
| @login_required | |||||
| def list_datasets(): | |||||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets") | |||||
| return list(kbs.dicts()) | return list(kbs.dicts()) | ||||
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                      offset, count, orderby, desc):
    """Return valid knowledge bases visible to *user_id*, paginated.

    A knowledge base qualifies when it belongs to one of the user's joined
    tenants with TEAM permission, or is owned by the user directly, and its
    status is VALID.

    Args:
        joined_tenant_ids: tenant ids the user is a member of.
        user_id: the user's own tenant/owner id.
        offset: index of the first row to return (0-based).
        count: maximum rows to return; a negative value means "no limit".
        orderby: model field name to sort by.
        desc: truthy for descending sort, otherwise ascending.

    Returns:
        A list of row dicts.

    Raises:
        IndexError: if offset is negative or beyond the result length.
    """
    kbs = cls.model.select().where(
        ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
                                                        TenantPermission.TEAM.value)) | (
                 cls.model.tenant_id == user_id))
        & (cls.model.status == StatusEnum.VALID.value)
    )
    if desc:
        kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
    else:
        kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
    kbs = list(kbs.dicts())
    kbs_length = len(kbs)
    if offset < 0 or offset > kbs_length:
        raise IndexError("Offset is out of the valid range.")
    # BUG FIX: with the sentinel count == -1 ("no limit") the original slice
    # kbs[offset:offset + count] became kbs[offset:offset - 1], silently
    # dropping the last row. Treat any negative count as "return the rest".
    if count < 0:
        return kbs[offset:]
    return kbs[offset:offset + count]
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def get_detail(cls, kb_id): | def get_detail(cls, kb_id): |
| import importlib.metadata | import importlib.metadata | ||||
| __version__ = importlib.metadata.version("ragflow") | __version__ = importlib.metadata.version("ragflow") | ||||
| from .ragflow import RAGFlow |
| self.user_key = user_key | self.user_key = user_key | ||||
| self.dataset_url = dataset_url | self.dataset_url = dataset_url | ||||
| self.uuid = uuid | self.uuid = uuid | ||||
| self.name = name | |||||
| self.name = name |
| import requests | import requests | ||||
| import json | import json | ||||
| class RAGFLow: | |||||
| from httpx import HTTPError | |||||
| class RAGFlow: | |||||
| def __init__(self, user_key, base_url, version = 'v1'): | def __init__(self, user_key, base_url, version = 'v1'): | ||||
| ''' | ''' | ||||
| api_url: http://<host_address>/api/v1 | api_url: http://<host_address>/api/v1 | ||||
| result_dict = json.loads(res.text) | result_dict = json.loads(res.text) | ||||
| return result_dict | return result_dict | ||||
| def delete_dataset(self, dataset_name = None, dataset_id = None): | |||||
| def delete_dataset(self, dataset_name=None, dataset_id=None): | |||||
| return dataset_name | return dataset_name | ||||
| def list_dataset(self): | |||||
| response = requests.get(self.dataset_url) | |||||
| print(response) | |||||
| if response.status_code == 200: | |||||
| return response.json()['datasets'] | |||||
| else: | |||||
| return None | |||||
def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
    """Fetch datasets from the server via GET on the dataset endpoint.

    Args:
        offset: index of the first dataset to return (default 0).
        count: maximum number of datasets to return; -1 means "all".
        orderby: server-side sort field (default "create_time").
        desc: sort in descending order when True.

    Returns:
        A ``(status_code, parsed_json)`` tuple on success. On failure the
        error is printed and ``None`` is returned implicitly (best-effort
        behaviour; callers that unpack the tuple must handle this).
    """
    params = {
        "offset": offset,
        "count": count,
        "orderby": orderby,
        "desc": desc
    }
    try:
        response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
        response.raise_for_status()  # raises HTTPError for non-2xx replies
        return response.status_code, response.json()
    except HTTPError as http_err:
        print(f"HTTP error occurred: {http_err}")
    except Exception as err:
        print(f"An error occurred: {err}")
| def get_dataset(self, dataset_id): | def get_dataset(self, dataset_id): | ||||
| endpoint = f"{self.dataset_url}/{dataset_id}" | endpoint = f"{self.dataset_url}/{dataset_id}" | ||||
| if response.status_code == 200: | if response.status_code == 200: | ||||
| return True | return True | ||||
| else: | else: | ||||
| return False | |||||
| return False |
| API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE' | |||||
| API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E' | |||||
| HOST_ADDRESS = 'http://127.0.0.1:9380' | HOST_ADDRESS = 'http://127.0.0.1:9380' |
| from test_sdkbase import TestSdk | from test_sdkbase import TestSdk | ||||
| import ragflow | |||||
| from ragflow.ragflow import RAGFLow | |||||
| from ragflow import RAGFlow | |||||
| import pytest | import pytest | ||||
| from unittest.mock import MagicMock | |||||
| from common import API_KEY, HOST_ADDRESS | from common import API_KEY, HOST_ADDRESS | ||||
| class TestDataset(TestSdk): | class TestDataset(TestSdk): | ||||
| def test_create_dataset(self): | def test_create_dataset(self): | ||||
| 4. update the kb | 4. update the kb | ||||
| 5. delete the kb | 5. delete the kb | ||||
| ''' | ''' | ||||
| ragflow = RAGFLow(API_KEY, HOST_ADDRESS) | |||||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||||
| # create a kb | # create a kb | ||||
| res = ragflow.create_dataset("kb1") | res = ragflow.create_dataset("kb1") | ||||
| assert res['code'] == 0 and res['message'] == 'success' | assert res['code'] == 0 and res['message'] == 'success' | ||||
| dataset_id = res['data']['dataset_id'] | |||||
| print(dataset_id) | |||||
| dataset_name = res['data']['dataset_name'] | |||||
def test_list_dataset_success(self):
    """A plain listing request should come back with HTTP 200."""
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    # Ask the server for the default dataset listing and unpack the tuple.
    status_code, _datasets = ragflow.list_dataset()
    assert status_code == 200
def test_list_dataset_with_checking_size_and_name(self):
    # Create three uniquely named datasets, then verify the listing
    # endpoint returns exactly those three names.
    datasets_to_create = ["dataset1", "dataset2", "dataset3"]
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    created_response = [ragflow.create_dataset(name) for name in datasets_to_create]

    # Collect the names the server actually stored (echoed back in 'data').
    real_name_to_create = set()
    for response in created_response:
        assert 'data' in response, "Response is missing 'data' key"
        dataset_name = response['data']['dataset_name']
        real_name_to_create.add(dataset_name)

    # List the first three datasets (offset=0, count=3).
    status_code, listed_data = ragflow.list_dataset(0, 3)
    listed_data = listed_data['data']

    listed_names = {d['name'] for d in listed_data}
    assert listed_names == real_name_to_create
    assert status_code == 200
    assert len(listed_data) == len(datasets_to_create)
def test_list_dataset_with_getting_empty_result(self):
    """Requesting zero items (count=0) must yield an empty dataset list."""
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    # No datasets are created here, so the expected name set stays empty.
    expected_names = set()
    for creation in [ragflow.create_dataset(n) for n in []]:
        assert 'data' in creation, "Response is missing 'data' key"
        expected_names.add(creation['data']['dataset_name'])

    status_code, body = ragflow.list_dataset(0, 0)
    datasets = body['data']

    assert {entry['name'] for entry in datasets} == expected_names
    assert status_code == 200
    assert len(datasets) == 0
def test_list_dataset_with_creating_100_knowledge_bases(self):
    # Stress test: create 100 datasets, all *requested* with the same
    # name, then list the first 100 back and compare.
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    datasets_to_create = ["dataset1"] * 100
    created_response = [ragflow.create_dataset(name) for name in datasets_to_create]

    # NOTE(review): every creation requests the same name, so this set only
    # matches the listing if the server de-duplicates names (e.g. appends a
    # suffix) and echoes the final stored name back — TODO confirm.
    real_name_to_create = set()
    for response in created_response:
        assert 'data' in response, "Response is missing 'data' key"
        dataset_name = response['data']['dataset_name']
        real_name_to_create.add(dataset_name)

    status_code, listed_data = ragflow.list_dataset(0, 100)
    listed_data = listed_data['data']

    listed_names = {d['name'] for d in listed_data}
    assert listed_names == real_name_to_create
    assert status_code == 200
    assert len(listed_data) == 100
def test_list_dataset_with_showing_one_dataset(self):
    """Listing with count=1 should return exactly one dataset entry."""
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    _code, body = ragflow.list_dataset(0, 1)
    page = body['data']
    assert len(page) == 1
def test_list_dataset_failure(self):
    """A negative offset must surface an IndexError message from the server."""
    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    _status, body = ragflow.list_dataset(-1, -1)
    assert "IndexError" in body['message']
| # TODO: list the kb |