### What problem does this PR solve?
This PR implements the 'create dataset' feature for both the HTTP API and the Python SDK.
HTTP API:
```
curl --request POST --url http://<HOST_ADDRESS>/api/v1/dataset --header 'Content-Type: application/json' --header 'Authorization: <ACCESS_KEY>' --data-binary '{
"name": "<DATASET_NAME>"
}'
```
Python SDK:
```
from ragflow.ragflow import RAGFLow
ragflow = RAGFLow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
ragflow.create_dataset("dataset1")
```
TODO:
- ACCESS_KEY is currently the login_token issued when a user logs in to RAGFlow.
RAGFlow should provide functionality for users to add/delete access keys.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
---------
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
tags/v0.8.0
| @@ -63,12 +63,17 @@ login_manager.init_app(app) | |||
| def search_pages_path(pages_dir): | |||
| return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] | |||
| app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] | |||
| api_path_list = [path for path in pages_dir.glob('*_api.py') if not path.name.startswith('.')] | |||
| app_path_list.extend(api_path_list) | |||
| return app_path_list | |||
| def register_page(page_path): | |||
| page_name = page_path.stem.rstrip('_app') | |||
| module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name, )) | |||
| path = f'{page_path}' | |||
| page_name = page_path.stem.rstrip('_api') if "_api" in path else page_path.stem.rstrip('_app') | |||
| module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,)) | |||
| spec = spec_from_file_location(module_name, page_path) | |||
| page = module_from_spec(spec) | |||
| @@ -76,17 +81,17 @@ def register_page(page_path): | |||
| page.manager = Blueprint(page_name, module_name) | |||
| sys.modules[module_name] = page | |||
| spec.loader.exec_module(page) | |||
| page_name = getattr(page, 'page_name', page_name) | |||
| url_prefix = f'/{API_VERSION}/{page_name}' | |||
| url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}' | |||
| app.register_blueprint(page.manager, url_prefix=url_prefix) | |||
| print(f'API file: {page_path}, URL: {url_prefix}') | |||
| return url_prefix | |||
| pages_dir = [ | |||
| Path(__file__).parent, | |||
| Path(__file__).parent.parent / 'api' / 'apps', | |||
| Path(__file__).parent.parent / 'api' / 'apps', # FIXME: ragflow/api/api/apps, can be remove? | |||
| ] | |||
| client_urls_prefix = [ | |||
| @@ -0,0 +1,142 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import json | |||
| import os | |||
| import re | |||
| from datetime import datetime, timedelta | |||
| from flask import request, Response | |||
| from flask_login import login_required, current_user | |||
| from api.db import FileType, ParserType, FileSource, StatusEnum | |||
| from api.db.db_models import APIToken, API4Conversation, Task, File | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.api_service import APITokenService, API4ConversationService | |||
| from api.db.services.dialog_service import DialogService, chat | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.task_service import queue_tasks, TaskService | |||
| from api.db.services.user_service import UserTenantService, TenantService | |||
| from api.settings import RetCode, retrievaler | |||
| from api.utils import get_uuid, current_timestamp, datetime_format | |||
| # from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request | |||
| from itsdangerous import URLSafeTimedSerializer | |||
| from api.utils.file_utils import filename_type, thumbnail | |||
| from rag.utils.minio_conn import MINIO | |||
| # import library | |||
| from api.utils.api_utils import construct_json_result, construct_result, construct_error_response, validate_request | |||
| from api.contants import NAME_LENGTH_LIMIT | |||
| # ------------------------------ create a dataset --------------------------------------- | |||
| @manager.route('/', methods=['POST']) | |||
| @login_required # use login | |||
| @validate_request("name") # check name key | |||
| def create_dataset(): | |||
| # Check if Authorization header is present | |||
| authorization_token = request.headers.get('Authorization') | |||
| if not authorization_token: | |||
| return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Authorization header is missing.") | |||
| # TODO: Login or API key | |||
| # objs = APIToken.query(token=authorization_token) | |||
| # | |||
| # # Authorization error | |||
| # if not objs: | |||
| # return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Token is invalid.") | |||
| # | |||
| # tenant_id = objs[0].tenant_id | |||
| tenant_id = current_user.id | |||
| request_body = request.json | |||
| # In case that there's no name | |||
| if "name" not in request_body: | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message="Expected 'name' field in request body") | |||
| dataset_name = request_body["name"] | |||
| # empty dataset_name | |||
| if not dataset_name: | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message="Empty dataset name") | |||
| # In case that there's space in the head or the tail | |||
| dataset_name = dataset_name.strip() | |||
| # In case that the length of the name exceeds the limit | |||
| dataset_name_length = len(dataset_name) | |||
| if dataset_name_length > NAME_LENGTH_LIMIT: | |||
| return construct_json_result( | |||
| message=f"Dataset name: {dataset_name} with length {dataset_name_length} exceeds {NAME_LENGTH_LIMIT}!") | |||
| # In case that there are other fields in the data-binary | |||
| if len(request_body.keys()) > 1: | |||
| name_list = [] | |||
| for key_name in request_body.keys(): | |||
| if key_name != 'name': | |||
| name_list.append(key_name) | |||
| return construct_json_result(code=RetCode.DATA_ERROR, | |||
| message=f"fields: {name_list}, are not allowed in request body.") | |||
| # If there is a duplicate name, it will modify it to make it unique | |||
| request_body["name"] = duplicate_name( | |||
| KnowledgebaseService.query, | |||
| name=dataset_name, | |||
| tenant_id=tenant_id, | |||
| status=StatusEnum.VALID.value) | |||
| try: | |||
| request_body["id"] = get_uuid() | |||
| request_body["tenant_id"] = tenant_id | |||
| request_body["created_by"] = tenant_id | |||
| e, t = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| return construct_result(code=RetCode.AUTHENTICATION_ERROR, message="Tenant not found.") | |||
| request_body["embd_id"] = t.embd_id | |||
| if not KnowledgebaseService.save(**request_body): | |||
| # failed to create new dataset | |||
| return construct_result() | |||
| return construct_json_result(data={"dataset_id": request_body["id"]}) | |||
| except Exception as e: | |||
| return construct_error_response(e) | |||
| @manager.route('/<dataset_id>', methods=['DELETE']) | |||
| @login_required | |||
| def remove_dataset(dataset_id): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to remove dataset: {dataset_id}") | |||
| @manager.route('/<dataset_id>', methods=['PUT']) | |||
| @login_required | |||
| @validate_request("name") | |||
| def update_dataset(dataset_id): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to update dataset: {dataset_id}") | |||
| @manager.route('/<dataset_id>', methods=['GET']) | |||
| @login_required | |||
| def get_dataset(dataset_id): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}") | |||
| @manager.route('/', methods=['GET']) | |||
| @login_required | |||
| def list_datasets(): | |||
| return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets") | |||
| @@ -0,0 +1,16 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| NAME_LENGTH_LIMIT = 2 ** 10 | |||
| @@ -239,4 +239,5 @@ class RetCode(IntEnum, CustomEnum): | |||
| RUNNING = 106 | |||
| PERMISSION_ERROR = 108 | |||
| AUTHENTICATION_ERROR = 109 | |||
| UNAUTHORIZED = 401 | |||
| SERVER_ERROR = 500 | |||
| @@ -38,7 +38,6 @@ from base64 import b64encode | |||
| from hmac import HMAC | |||
| from urllib.parse import quote, urlencode | |||
| requests.models.complexjson.dumps = functools.partial( | |||
| json.dumps, cls=CustomJSONEncoder) | |||
| @@ -235,3 +234,35 @@ def cors_reponse(retcode=RetCode.SUCCESS, | |||
| response.headers["Access-Control-Allow-Headers"] = "*" | |||
| response.headers["Access-Control-Expose-Headers"] = "Authorization" | |||
| return response | |||
| def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): | |||
| import re | |||
| result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} | |||
| response = {} | |||
| for key, value in result_dict.items(): | |||
| if value is None and key != "code": | |||
| continue | |||
| else: | |||
| response[key] = value | |||
| return jsonify(response) | |||
| def construct_json_result(code=RetCode.SUCCESS, message='success', data=None): | |||
| if data == None: | |||
| return jsonify({"code": code, "message": message}) | |||
| else: | |||
| return jsonify({"code": code, "message": message, "data": data}) | |||
| def construct_error_response(e): | |||
| stat_logger.exception(e) | |||
| try: | |||
| if e.code == 401: | |||
| return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) | |||
| except BaseException: | |||
| pass | |||
| if len(e.args) > 1: | |||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) | |||
| if repr(e).find("index_not_found_exception") >=0: | |||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") | |||
| return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) | |||
| @@ -1 +1,41 @@ | |||
| # ragflow | |||
| # python-ragflow | |||
| # update python client | |||
| - Update the "version" field of the [project] section | |||
| - build new python SDK | |||
| - upload to pypi.org | |||
| - install new python SDK | |||
| # build python SDK | |||
| ```shell | |||
| rm -f dist/* && python setup.py sdist bdist_wheel | |||
| ``` | |||
| # install python SDK | |||
| ```shell | |||
| pip uninstall -y ragflow && pip install dist/*.whl | |||
| ``` | |||
| This will install ragflow-sdk and its dependencies. | |||
| # upload to pypi.org | |||
| ```shell | |||
| twine upload dist/*.whl | |||
| ``` | |||
| Enter your PyPI API token when prompted. | |||
| Note that PyPI allows a given version of a package to [be uploaded only once](https://pypi.org/help/#file-name-reuse). You need to change the `version` inside `pyproject.toml` before building and uploading. | |||
| # using | |||
| ```python | |||
| ``` | |||
| # For developer | |||
| ```shell | |||
| pip install -e . | |||
| ``` | |||
| @@ -0,0 +1,21 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| class DataSet: | |||
| def __init__(self, user_key, dataset_url, uuid, name): | |||
| self.user_key = user_key | |||
| self.dataset_url = dataset_url | |||
| self.uuid = uuid | |||
| self.name = name | |||
| @@ -12,33 +12,43 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from abc import ABC | |||
| import requests | |||
| import json | |||
| class RAGFLow(ABC): | |||
| def __init__(self, user_key, base_url): | |||
| class RAGFLow: | |||
| def __init__(self, user_key, base_url, version = 'v1'): | |||
| ''' | |||
| api_url: http://<host_address>/api/v1 | |||
| dataset_url: http://<host_address>/api/v1/dataset | |||
| ''' | |||
| self.user_key = user_key | |||
| self.base_url = base_url | |||
| self.api_url = f"{base_url}/api/{version}" | |||
| self.dataset_url = f"{self.api_url}/dataset" | |||
| self.authorization_header = {"Authorization": "{}".format(self.user_key)} | |||
| def create_dataset(self, name): | |||
| return name | |||
| def create_dataset(self, dataset_name): | |||
| """ | |||
| name: dataset name | |||
| """ | |||
| res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header) | |||
| result_dict = json.loads(res.text) | |||
| return result_dict | |||
| def delete_dataset(self, name): | |||
| return name | |||
| def delete_dataset(self, dataset_name = None, dataset_id = None): | |||
| return dataset_name | |||
| def list_dataset(self): | |||
| endpoint = f"{self.base_url}/api/v1/dataset" | |||
| response = requests.get(endpoint) | |||
| response = requests.get(self.dataset_url) | |||
| print(response) | |||
| if response.status_code == 200: | |||
| return response.json()['datasets'] | |||
| else: | |||
| return None | |||
| def get_dataset(self, dataset_id): | |||
| endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}" | |||
| endpoint = f"{self.dataset_url}/{dataset_id}" | |||
| response = requests.get(endpoint) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| @@ -46,7 +56,7 @@ class RAGFLow(ABC): | |||
| return None | |||
| def update_dataset(self, dataset_id, params): | |||
| endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}" | |||
| endpoint = f"{self.dataset_url}/{dataset_id}" | |||
| response = requests.put(endpoint, json=params) | |||
| if response.status_code == 200: | |||
| return True | |||
| @@ -0,0 +1,4 @@ | |||
| API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE' | |||
| HOST_ADDRESS = 'http://127.0.0.1:9380' | |||
| @@ -3,49 +3,46 @@ import ragflow | |||
| from ragflow.ragflow import RAGFLow | |||
| import pytest | |||
| from unittest.mock import MagicMock | |||
| from common import API_KEY, HOST_ADDRESS | |||
| class TestCase(TestSdk): | |||
| @pytest.fixture | |||
| def ragflow_instance(self): | |||
| # Here we create a mock instance of RAGFlow for testing | |||
| return ragflow.ragflow.RAGFLow('123', 'url') | |||
| class TestBasic(TestSdk): | |||
| def test_version(self): | |||
| print(ragflow.__version__) | |||
| def test_create_dataset(self): | |||
| assert ragflow.ragflow.RAGFLow('123', 'url').create_dataset('abc') == 'abc' | |||
| def test_delete_dataset(self): | |||
| assert ragflow.ragflow.RAGFLow('123', 'url').delete_dataset('abc') == 'abc' | |||
| def test_list_dataset_success(self, ragflow_instance, monkeypatch): | |||
| # Mocking the response of requests.get method | |||
| mock_response = MagicMock() | |||
| mock_response.status_code = 200 | |||
| mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]} | |||
| # Patching requests.get to return the mock_response | |||
| monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) | |||
| # Call the method under test | |||
| result = ragflow_instance.list_dataset() | |||
| # Assertion | |||
| assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}] | |||
| def test_list_dataset_failure(self, ragflow_instance, monkeypatch): | |||
| # Mocking the response of requests.get method | |||
| mock_response = MagicMock() | |||
| mock_response.status_code = 404 # Simulating a failed request | |||
| # Patching requests.get to return the mock_response | |||
| monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) | |||
| # Call the method under test | |||
| result = ragflow_instance.list_dataset() | |||
| # Assertion | |||
| assert result is None | |||
| # def test_create_dataset(self): | |||
| # res = RAGFLow(API_KEY, HOST_ADDRESS).create_dataset('abc') | |||
| # print(res) | |||
| # | |||
| # def test_delete_dataset(self): | |||
| # assert RAGFLow('123', 'url').delete_dataset('abc') == 'abc' | |||
| # | |||
| # def test_list_dataset_success(self, ragflow_instance, monkeypatch): | |||
| # # Mocking the response of requests.get method | |||
| # mock_response = MagicMock() | |||
| # mock_response.status_code = 200 | |||
| # mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]} | |||
| # | |||
| # # Patching requests.get to return the mock_response | |||
| # monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) | |||
| # | |||
| # # Call the method under test | |||
| # result = ragflow_instance.list_dataset() | |||
| # | |||
| # # Assertion | |||
| # assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}] | |||
| # | |||
| # def test_list_dataset_failure(self, ragflow_instance, monkeypatch): | |||
| # # Mocking the response of requests.get method | |||
| # mock_response = MagicMock() | |||
| # mock_response.status_code = 404 # Simulating a failed request | |||
| # | |||
| # # Patching requests.get to return the mock_response | |||
| # monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) | |||
| # | |||
| # # Call the method under test | |||
| # result = ragflow_instance.list_dataset() | |||
| # | |||
| # # Assertion | |||
| # assert result is None | |||
| @@ -0,0 +1,26 @@ | |||
| from test_sdkbase import TestSdk | |||
| import ragflow | |||
| from ragflow.ragflow import RAGFLow | |||
| import pytest | |||
| from unittest.mock import MagicMock | |||
| from common import API_KEY, HOST_ADDRESS | |||
| class TestDataset(TestSdk): | |||
| def test_create_dataset(self): | |||
| ''' | |||
| 1. create a kb | |||
| 2. list the kb | |||
| 3. get the detail info according to the kb id | |||
| 4. update the kb | |||
| 5. delete the kb | |||
| ''' | |||
| ragflow = RAGFLow(API_KEY, HOST_ADDRESS) | |||
| # create a kb | |||
| res = ragflow.create_dataset("kb1") | |||
| assert res['code'] == 0 and res['message'] == 'success' | |||
| dataset_id = res['data']['dataset_id'] | |||
| print(dataset_id) | |||
| # TODO: list the kb | |||