### What problem does this PR solve? This PR completes both the HTTP API and the Python SDK for `list_dataset`. In addition, there are tests for it. ### Type of change - [x] New Feature (non-breaking change which adds functionality) tags/v0.8.0
| @@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT | |||
| # ------------------------------ create a dataset --------------------------------------- | |||
| @manager.route('/', methods=['POST']) | |||
| @login_required # use login | |||
| @login_required # use login | |||
| @validate_request("name") # check name key | |||
| def create_dataset(): | |||
| # Check if Authorization header is present | |||
| @@ -111,10 +111,27 @@ def create_dataset(): | |||
| if not KnowledgebaseService.save(**request_body): | |||
| # failed to create new dataset | |||
| return construct_result() | |||
| return construct_json_result(data={"dataset_id": request_body["id"]}) | |||
| return construct_json_result(data={"dataset_name": request_body["name"]}) | |||
| except Exception as e: | |||
| return construct_error_response(e) | |||
# -----------------------------list datasets-------------------------------------------------------
@manager.route('/', methods=['GET'])
@login_required
def list_datasets():
    """List the knowledge bases visible to the logged-in user.

    Query parameters:
        offset (int): number of leading datasets to skip (default 0).
        count (int): max datasets to return; -1 means "all remaining".
        orderby (str): column to sort by (default "create_time").
        desc (bool): sort descending when true (default true).
    """
    offset = request.args.get("offset", 0)
    count = request.args.get("count", -1)
    orderby = request.args.get("orderby", "create_time")
    # BUG FIX: query-string values are strings, so the old
    # `request.args.get("desc", True)` made "?desc=False" truthy and
    # always sorted descending. Parse it explicitly instead.
    desc = str(request.args.get("desc", True)).lower() not in ("false", "0")
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        kbs = KnowledgebaseService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id,
            int(offset), int(count), orderby, desc)
        # BUG FIX: the success path previously reported code=RetCode.DATA_ERROR;
        # omit the code so the helper uses its success default, matching
        # create_dataset's success response.
        return construct_json_result(data=kbs, message="List datasets")
    except Exception as e:
        return construct_error_response(e)
| # ---------------------------------delete a dataset ---------------------------- | |||
| @manager.route('/<dataset_id>', methods=['DELETE']) | |||
| @login_required | |||
| @@ -135,8 +152,5 @@ def get_dataset(dataset_id): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}") | |||
| @manager.route('/', methods=['GET']) | |||
| @login_required | |||
| def list_datasets(): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets") | |||
| @@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService): | |||
| return list(kbs.dicts()) | |||
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                      offset, count, orderby, desc):
    """Return the knowledge bases visible to *user_id* as a list of dicts.

    Visible means: owned by the user, or shared with TEAM permission by any
    tenant in *joined_tenant_ids*; soft-deleted rows (invalid status) are
    excluded. Results are sorted by *orderby* (descending when *desc* is
    truthy) and then paginated in Python.

    Args:
        joined_tenant_ids: tenant ids the user has joined.
        user_id: the requesting user's id.
        offset: rows to skip; must lie in [0, total].
        count: max rows to return; a negative value means "all remaining".
        orderby: model column name to sort by.
        desc: sort direction flag.

    Raises:
        IndexError: if *offset* is negative or beyond the result size.
    """
    kbs = cls.model.select().where(
        ((cls.model.tenant_id.in_(joined_tenant_ids)
          & (cls.model.permission == TenantPermission.TEAM.value))
         | (cls.model.tenant_id == user_id))
        & (cls.model.status == StatusEnum.VALID.value)
    )
    order_field = cls.model.getter_by(orderby)
    kbs = kbs.order_by(order_field.desc() if desc else order_field.asc())
    kbs = list(kbs.dicts())

    total = len(kbs)
    if offset < 0 or offset > total:
        raise IndexError("Offset is out of the valid range.")
    # BUG FIX: count defaults to -1 meaning "everything from offset", but the
    # old `kbs[offset:offset + count]` turned -1 into an end index that
    # silently dropped the final row (e.g. kbs[0:-1]). Handle it explicitly.
    if count < 0:
        return kbs[offset:]
    return kbs[offset:offset + count]
| @classmethod | |||
| @DB.connection_context() | |||
| def get_detail(cls, kb_id): | |||
| @@ -1,3 +1,5 @@ | |||
| import importlib.metadata | |||
| __version__ = importlib.metadata.version("ragflow") | |||
| from .ragflow import RAGFlow | |||
| @@ -18,4 +18,4 @@ class DataSet: | |||
| self.user_key = user_key | |||
| self.dataset_url = dataset_url | |||
| self.uuid = uuid | |||
| self.name = name | |||
| self.name = name | |||
| @@ -17,7 +17,10 @@ import os | |||
| import requests | |||
| import json | |||
| class RAGFLow: | |||
| from httpx import HTTPError | |||
| class RAGFlow: | |||
| def __init__(self, user_key, base_url, version = 'v1'): | |||
| ''' | |||
| api_url: http://<host_address>/api/v1 | |||
| @@ -36,16 +39,39 @@ class RAGFLow: | |||
| result_dict = json.loads(res.text) | |||
| return result_dict | |||
| def delete_dataset(self, dataset_name = None, dataset_id = None): | |||
| def delete_dataset(self, dataset_name=None, dataset_id=None): | |||
| return dataset_name | |||
| def list_dataset(self): | |||
| response = requests.get(self.dataset_url) | |||
| print(response) | |||
| if response.status_code == 200: | |||
| return response.json()['datasets'] | |||
| else: | |||
| return None | |||
| def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True): | |||
| params = { | |||
| "offset": offset, | |||
| "count": count, | |||
| "orderby": orderby, | |||
| "desc": desc | |||
| } | |||
| try: | |||
| response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header) | |||
| response.raise_for_status() # if it is not 200 | |||
| original_data = response.json() | |||
| # TODO: format the data | |||
| # print(original_data) | |||
| # # Process the original data into the desired format | |||
| # formatted_data = { | |||
| # "datasets": [ | |||
| # { | |||
| # "id": dataset["id"], | |||
| # "created": dataset["create_time"], # Adjust the key based on the actual response | |||
| # "fileCount": dataset["doc_num"], # Adjust the key based on the actual response | |||
| # "name": dataset["name"] | |||
| # } | |||
| # for dataset in original_data | |||
| # ] | |||
| # } | |||
| return response.status_code, original_data | |||
| except HTTPError as http_err: | |||
| print(f"HTTP error occurred: {http_err}") | |||
| except Exception as err: | |||
| print(f"An error occurred: {err}") | |||
| def get_dataset(self, dataset_id): | |||
| endpoint = f"{self.dataset_url}/{dataset_id}" | |||
| @@ -61,4 +87,4 @@ class RAGFLow: | |||
| if response.status_code == 200: | |||
| return True | |||
| else: | |||
| return False | |||
| return False | |||
| @@ -1,4 +1,4 @@ | |||
| API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE' | |||
| API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E' | |||
| HOST_ADDRESS = 'http://127.0.0.1:9380' | |||
| @@ -1,10 +1,10 @@ | |||
| from test_sdkbase import TestSdk | |||
| import ragflow | |||
| from ragflow.ragflow import RAGFLow | |||
| from ragflow import RAGFlow | |||
| import pytest | |||
| from unittest.mock import MagicMock | |||
| from common import API_KEY, HOST_ADDRESS | |||
| class TestDataset(TestSdk): | |||
| def test_create_dataset(self): | |||
| @@ -15,12 +15,92 @@ class TestDataset(TestSdk): | |||
| 4. update the kb | |||
| 5. delete the kb | |||
| ''' | |||
| ragflow = RAGFLow(API_KEY, HOST_ADDRESS) | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| # create a kb | |||
| res = ragflow.create_dataset("kb1") | |||
| assert res['code'] == 0 and res['message'] == 'success' | |||
| dataset_id = res['data']['dataset_id'] | |||
| print(dataset_id) | |||
| dataset_name = res['data']['dataset_name'] | |||
| def test_list_dataset_success(self): | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| # Call the list_datasets method | |||
| response = ragflow.list_dataset() | |||
| code, datasets = response | |||
| assert code == 200 | |||
| def test_list_dataset_with_checking_size_and_name(self): | |||
| datasets_to_create = ["dataset1", "dataset2", "dataset3"] | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| created_response = [ragflow.create_dataset(name) for name in datasets_to_create] | |||
| real_name_to_create = set() | |||
| for response in created_response: | |||
| assert 'data' in response, "Response is missing 'data' key" | |||
| dataset_name = response['data']['dataset_name'] | |||
| real_name_to_create.add(dataset_name) | |||
| status_code, listed_data = ragflow.list_dataset(0, 3) | |||
| listed_data = listed_data['data'] | |||
| listed_names = {d['name'] for d in listed_data} | |||
| assert listed_names == real_name_to_create | |||
| assert status_code == 200 | |||
| assert len(listed_data) == len(datasets_to_create) | |||
| def test_list_dataset_with_getting_empty_result(self): | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| datasets_to_create = [] | |||
| created_response = [ragflow.create_dataset(name) for name in datasets_to_create] | |||
| real_name_to_create = set() | |||
| for response in created_response: | |||
| assert 'data' in response, "Response is missing 'data' key" | |||
| dataset_name = response['data']['dataset_name'] | |||
| real_name_to_create.add(dataset_name) | |||
| status_code, listed_data = ragflow.list_dataset(0, 0) | |||
| listed_data = listed_data['data'] | |||
| listed_names = {d['name'] for d in listed_data} | |||
| assert listed_names == real_name_to_create | |||
| assert status_code == 200 | |||
| assert len(listed_data) == 0 | |||
| def test_list_dataset_with_creating_100_knowledge_bases(self): | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| datasets_to_create = ["dataset1"] * 100 | |||
| created_response = [ragflow.create_dataset(name) for name in datasets_to_create] | |||
| real_name_to_create = set() | |||
| for response in created_response: | |||
| assert 'data' in response, "Response is missing 'data' key" | |||
| dataset_name = response['data']['dataset_name'] | |||
| real_name_to_create.add(dataset_name) | |||
| status_code, listed_data = ragflow.list_dataset(0, 100) | |||
| listed_data = listed_data['data'] | |||
| listed_names = {d['name'] for d in listed_data} | |||
| assert listed_names == real_name_to_create | |||
| assert status_code == 200 | |||
| assert len(listed_data) == 100 | |||
| def test_list_dataset_with_showing_one_dataset(self): | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| response = ragflow.list_dataset(0, 1) | |||
| code, response = response | |||
| datasets = response['data'] | |||
| assert len(datasets) == 1 | |||
| def test_list_dataset_failure(self): | |||
| ragflow = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| response = ragflow.list_dataset(-1, -1) | |||
| _, res = response | |||
| assert "IndexError" in res['message'] | |||
| # TODO: list the kb | |||