### What problem does this PR solve?
This PR has implemented the 'create dataset' feature for both the HTTP API and the Python SDK.
HTTP API:
```
curl --request POST --url http://<HOST_ADDRESS>/api/v1/dataset --header 'Content-Type: application/json' --header 'Authorization: <ACCESS_KEY>' --data-binary '{
"name": "<DATASET_NAME>"
}'
```
Python SDK:
```
from ragflow.ragflow import RAGFLow
ragflow = RAGFLow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
ragflow.create_dataset("dataset1")
```
TODO:
- Currently, ACCESS_KEY is the login token issued when a user logs in to RAGFlow.
RAGFlow should provide the ability for users to add and delete access keys.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
---------
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
tags/v0.8.0
def search_pages_path(pages_dir):
    """Collect the Flask page modules to register from *pages_dir*.

    Returns both web-app pages (``*_app.py``) and HTTP-API pages
    (``*_api.py``), skipping hidden files (names starting with '.').
    App pages come first, followed by API pages.

    :param pages_dir: a ``pathlib.Path`` directory to scan (non-recursive).
    :return: list of ``Path`` objects for the matching module files.
    """
    app_path_list = [path for path in pages_dir.glob('*_app.py')
                     if not path.name.startswith('.')]
    api_path_list = [path for path in pages_dir.glob('*_api.py')
                     if not path.name.startswith('.')]
    app_path_list.extend(api_path_list)
    return app_path_list
def register_page(page_path):
    """Import the module at *page_path*, attach a Blueprint to it, and
    register that blueprint on the Flask app.

    API modules (``*_api.py``) are mounted under ``/api/<API_VERSION>/...``;
    regular app pages under ``/<API_VERSION>/...``.

    :param page_path: ``pathlib.Path`` to a ``*_app.py`` / ``*_api.py`` file.
    :return: the URL prefix the blueprint was registered under.
    """
    path = f'{page_path}'
    # BUG FIX: str.rstrip strips a *character set*, not a suffix
    # (e.g. 'data_app'.rstrip('_app') -> 'dat'), so remove the suffix explicitly.
    suffix = '_api' if '_api' in path else '_app'
    page_name = page_path.stem
    if page_name.endswith(suffix):
        page_name = page_name[:-len(suffix)]
    # Module name is the dotted path relative to the 'api' package root.
    module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))

    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    # The module body expects a 'manager' blueprint to hang routes off.
    page.manager = Blueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)
    # A module may override its public page name via a 'page_name' attribute.
    page_name = getattr(page, 'page_name', page_name)
    url_prefix = f'/api/{API_VERSION}/{page_name}' if '_api' in path else f'/{API_VERSION}/{page_name}'

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    print(f'API file: {page_path}, URL: {url_prefix}')
    return url_prefix
| pages_dir = [ | pages_dir = [ | ||||
| Path(__file__).parent, | Path(__file__).parent, | ||||
| Path(__file__).parent.parent / 'api' / 'apps', | |||||
| Path(__file__).parent.parent / 'api' / 'apps', # FIXME: ragflow/api/api/apps, can be remove? | |||||
| ] | ] | ||||
| client_urls_prefix = [ | client_urls_prefix = [ |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import json | |||||
| import os | |||||
| import re | |||||
| from datetime import datetime, timedelta | |||||
| from flask import request, Response | |||||
| from flask_login import login_required, current_user | |||||
| from api.db import FileType, ParserType, FileSource, StatusEnum | |||||
| from api.db.db_models import APIToken, API4Conversation, Task, File | |||||
| from api.db.services import duplicate_name | |||||
| from api.db.services.api_service import APITokenService, API4ConversationService | |||||
| from api.db.services.dialog_service import DialogService, chat | |||||
| from api.db.services.document_service import DocumentService | |||||
| from api.db.services.file2document_service import File2DocumentService | |||||
| from api.db.services.file_service import FileService | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||||
| from api.db.services.task_service import queue_tasks, TaskService | |||||
| from api.db.services.user_service import UserTenantService, TenantService | |||||
| from api.settings import RetCode, retrievaler | |||||
| from api.utils import get_uuid, current_timestamp, datetime_format | |||||
| # from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request | |||||
| from itsdangerous import URLSafeTimedSerializer | |||||
| from api.utils.file_utils import filename_type, thumbnail | |||||
| from rag.utils.minio_conn import MINIO | |||||
| # import library | |||||
| from api.utils.api_utils import construct_json_result, construct_result, construct_error_response, validate_request | |||||
| from api.contants import NAME_LENGTH_LIMIT | |||||
# ------------------------------ create a dataset ---------------------------------------
@manager.route('/', methods=['POST'])
@login_required  # use login
@validate_request("name")  # check name key
def create_dataset():
    """Create a new dataset (knowledgebase) owned by the current user's tenant.

    Expects a JSON body with exactly one field, ``name``. On success returns a
    JSON result whose data contains the new ``dataset_id``; otherwise a
    DATA_ERROR / AUTHENTICATION_ERROR result describing the problem.
    """
    # Check if Authorization header is present
    authorization_token = request.headers.get('Authorization')
    if not authorization_token:
        return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Authorization header is missing.")

    # TODO: Login or API key
    # objs = APIToken.query(token=authorization_token)
    #
    # # Authorization error
    # if not objs:
    #     return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Token is invalid.")
    #
    # tenant_id = objs[0].tenant_id
    tenant_id = current_user.id
    request_body = request.json

    # In case that there's no name
    if "name" not in request_body:
        return construct_json_result(code=RetCode.DATA_ERROR, message="Expected 'name' field in request body")

    dataset_name = request_body["name"]
    # empty dataset_name
    if not dataset_name:
        return construct_json_result(code=RetCode.DATA_ERROR, message="Empty dataset name")

    # In case that there's space in the head or the tail
    dataset_name = dataset_name.strip()

    # In case that the length of the name exceeds the limit
    dataset_name_length = len(dataset_name)
    if dataset_name_length > NAME_LENGTH_LIMIT:
        # BUG FIX: this error previously fell back to the default SUCCESS code.
        return construct_json_result(
            code=RetCode.DATA_ERROR,
            message=f"Dataset name: {dataset_name} with length {dataset_name_length} exceeds {NAME_LENGTH_LIMIT}!")

    # In case that there are other fields in the data-binary
    if len(request_body.keys()) > 1:
        name_list = [key_name for key_name in request_body.keys() if key_name != 'name']
        return construct_json_result(code=RetCode.DATA_ERROR,
                                     message=f"fields: {name_list}, are not allowed in request body.")

    # If there is a duplicate name, it will modify it to make it unique
    request_body["name"] = duplicate_name(
        KnowledgebaseService.query,
        name=dataset_name,
        tenant_id=tenant_id,
        status=StatusEnum.VALID.value)
    try:
        request_body["id"] = get_uuid()
        request_body["tenant_id"] = tenant_id
        request_body["created_by"] = tenant_id
        exists, tenant = TenantService.get_by_id(tenant_id)
        if not exists:
            return construct_result(code=RetCode.AUTHENTICATION_ERROR, message="Tenant not found.")
        # The dataset inherits the tenant's default embedding model.
        request_body["embd_id"] = tenant.embd_id
        if not KnowledgebaseService.save(**request_body):
            # failed to create new dataset
            return construct_result()
        return construct_json_result(data={"dataset_id": request_body["id"]})
    except Exception as e:
        return construct_error_response(e)
@manager.route('/<dataset_id>', methods=['DELETE'])
@login_required
def remove_dataset(dataset_id):
    """Delete a dataset by id. TODO: not implemented yet — echoes the request."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to remove dataset: {dataset_id}")
@manager.route('/<dataset_id>', methods=['PUT'])
@login_required
@validate_request("name")
def update_dataset(dataset_id):
    """Update a dataset by id. TODO: not implemented yet — echoes the request."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to update dataset: {dataset_id}")
@manager.route('/<dataset_id>', methods=['GET'])
@login_required
def get_dataset(dataset_id):
    """Fetch a dataset's details by id. TODO: not implemented yet — echoes the request."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
@manager.route('/', methods=['GET'])
@login_required
def list_datasets():
    """List the current tenant's datasets. TODO: not implemented yet."""
    # FIX: dropped the extraneous f-prefix (the string has no placeholders).
    return construct_json_result(code=RetCode.DATA_ERROR, message="attempt to list datasets")
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
# Maximum allowed length of a dataset name, in characters (1024).
NAME_LENGTH_LIMIT = 2 ** 10
| RUNNING = 106 | RUNNING = 106 | ||||
| PERMISSION_ERROR = 108 | PERMISSION_ERROR = 108 | ||||
| AUTHENTICATION_ERROR = 109 | AUTHENTICATION_ERROR = 109 | ||||
| UNAUTHORIZED = 401 | |||||
| SERVER_ERROR = 500 | SERVER_ERROR = 500 |
| from hmac import HMAC | from hmac import HMAC | ||||
| from urllib.parse import quote, urlencode | from urllib.parse import quote, urlencode | ||||
| requests.models.complexjson.dumps = functools.partial( | requests.models.complexjson.dumps = functools.partial( | ||||
| json.dumps, cls=CustomJSONEncoder) | json.dumps, cls=CustomJSONEncoder) | ||||
| response.headers["Access-Control-Allow-Headers"] = "*" | response.headers["Access-Control-Allow-Headers"] = "*" | ||||
| response.headers["Access-Control-Expose-Headers"] = "Authorization" | response.headers["Access-Control-Expose-Headers"] = "Authorization" | ||||
| return response | return response | ||||
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
    """Build a JSON error response.

    The brand name 'rag' is rewritten to 'seceum' in the message
    (case-insensitive). Fields other than 'code' are omitted when None.
    """
    import re
    result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
    response = {key: value for key, value in result_dict.items()
                if value is not None or key == "code"}
    return jsonify(response)
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
    """Build a JSON result payload; 'data' is included only when present."""
    # FIX: use 'is None' — '== None' invokes data's __eq__, which misbehaves
    # for types with element-wise equality (e.g. numpy arrays).
    if data is None:
        return jsonify({"code": code, "message": message})
    return jsonify({"code": code, "message": message, "data": data})
def construct_error_response(e):
    """Log the exception and translate it into a JSON error result."""
    stat_logger.exception(e)
    # FIX: replaced 'except BaseException: pass' probing of e.code with
    # getattr — HTTP errors (e.g. werkzeug exceptions) carry a 'code' attribute.
    if getattr(e, "code", None) == 401:
        return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
    if len(e.args) > 1:
        return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
    # Elasticsearch raises this when the tenant's index has no chunks yet.
    if "index_not_found_exception" in repr(e):
        return construct_json_result(code=RetCode.EXCEPTION_ERROR,
                                     message="No chunk found, please upload file and parse it.")
    return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
| # ragflow | |||||
| # python-ragflow | |||||
| # update python client | |||||
| - Update "version" field of [project] chapter | |||||
| - build new python SDK | |||||
| - upload to pypi.org | |||||
| - install new python SDK | |||||
| # build python SDK | |||||
| ```shell | |||||
| rm -f dist/* && python setup.py sdist bdist_wheel | |||||
| ``` | |||||
| # install python SDK | |||||
| ```shell | |||||
| pip uninstall -y ragflow && pip install dist/*.whl | |||||
| ``` | |||||
| This will install ragflow-sdk and its dependencies. | |||||
| # upload to pypi.org | |||||
| ```shell | |||||
| twine upload dist/*.whl | |||||
| ``` | |||||
| Enter your pypi API token according to the prompt. | |||||
Note that PyPI allows a given version of a package to [be uploaded only once](https://pypi.org/help/#file-name-reuse). You need to change the `version` inside the `pyproject.toml` before building and uploading.
| # using | |||||
| ```python | |||||
| ``` | |||||
| # For developer | |||||
| ```shell | |||||
| pip install -e . | |||||
| ``` |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
class DataSet:
    """Client-side handle for a dataset that lives on a remote RAGFlow server."""

    def __init__(self, user_key, dataset_url, uuid, name):
        # Credential used to authorize requests for this dataset.
        self.user_key = user_key
        # Base endpoint for dataset operations.
        self.dataset_url = dataset_url
        # Server-assigned unique id of the dataset.
        self.uuid = uuid
        # Human-readable dataset name.
        self.name = name
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | |||||
| import os | import os | ||||
| from abc import ABC | |||||
| import requests | import requests | ||||
| import json | |||||
class RAGFLow:
    def __init__(self, user_key, base_url, version='v1'):
        """Create an SDK client.

        user_key: API/login token sent in the Authorization header.
        base_url: server root, e.g. http://<host_address>
        version:  API version segment; endpoints become
                  api_url     = http://<host_address>/api/v1
                  dataset_url = http://<host_address>/api/v1/dataset
        """
        self.user_key = user_key
        self.api_url = f"{base_url}/api/{version}"
        self.dataset_url = f"{self.api_url}/dataset"
        self.authorization_header = {"Authorization": "{}".format(self.user_key)}
| def create_dataset(self, name): | |||||
| return name | |||||
| def create_dataset(self, dataset_name): | |||||
| """ | |||||
| name: dataset name | |||||
| """ | |||||
| res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header) | |||||
| result_dict = json.loads(res.text) | |||||
| return result_dict | |||||
| def delete_dataset(self, name): | |||||
| return name | |||||
| def delete_dataset(self, dataset_name = None, dataset_id = None): | |||||
| return dataset_name | |||||
| def list_dataset(self): | def list_dataset(self): | ||||
| endpoint = f"{self.base_url}/api/v1/dataset" | |||||
| response = requests.get(endpoint) | |||||
| response = requests.get(self.dataset_url) | |||||
| print(response) | |||||
| if response.status_code == 200: | if response.status_code == 200: | ||||
| return response.json()['datasets'] | return response.json()['datasets'] | ||||
| else: | else: | ||||
| return None | return None | ||||
| def get_dataset(self, dataset_id): | def get_dataset(self, dataset_id): | ||||
| endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}" | |||||
| endpoint = f"{self.dataset_url}/{dataset_id}" | |||||
| response = requests.get(endpoint) | response = requests.get(endpoint) | ||||
| if response.status_code == 200: | if response.status_code == 200: | ||||
| return response.json() | return response.json() | ||||
| return None | return None | ||||
| def update_dataset(self, dataset_id, params): | def update_dataset(self, dataset_id, params): | ||||
| endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}" | |||||
| endpoint = f"{self.dataset_url}/{dataset_id}" | |||||
| response = requests.put(endpoint, json=params) | response = requests.put(endpoint, json=params) | ||||
| if response.status_code == 200: | if response.status_code == 200: | ||||
| return True | return True |
# NOTE(review): this looks like a real login token checked into source control —
# it should be rotated and loaded from an environment variable instead.
API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE'
# Address of the local RAGFlow server used by the SDK tests.
HOST_ADDRESS = 'http://127.0.0.1:9380'
| from ragflow.ragflow import RAGFLow | from ragflow.ragflow import RAGFLow | ||||
| import pytest | import pytest | ||||
| from unittest.mock import MagicMock | from unittest.mock import MagicMock | ||||
| from common import API_KEY, HOST_ADDRESS | |||||
class TestCase(TestSdk):
    """SDK smoke tests.

    The HTTP-dependent dataset tests are disabled until a live-server fixture
    exists; see git history for the previously mocked versions.
    """

    @pytest.fixture
    def ragflow_instance(self):
        # Here we create a mock instance of RAGFlow for testing
        return ragflow.ragflow.RAGFLow('123', 'url')

    def test_version(self):
        print(ragflow.__version__)
| from test_sdkbase import TestSdk | |||||
| import ragflow | |||||
| from ragflow.ragflow import RAGFLow | |||||
| import pytest | |||||
| from unittest.mock import MagicMock | |||||
| from common import API_KEY, HOST_ADDRESS | |||||
| class TestDataset(TestSdk): | |||||
| def test_create_dataset(self): | |||||
| ''' | |||||
| 1. create a kb | |||||
| 2. list the kb | |||||
| 3. get the detail info according to the kb id | |||||
| 4. update the kb | |||||
| 5. delete the kb | |||||
| ''' | |||||
| ragflow = RAGFLow(API_KEY, HOST_ADDRESS) | |||||
| # create a kb | |||||
| res = ragflow.create_dataset("kb1") | |||||
| assert res['code'] == 0 and res['message'] == 'success' | |||||
| dataset_id = res['data']['dataset_id'] | |||||
| print(dataset_id) | |||||
| # TODO: list the kb |