Selaa lähdekoodia

complete implementation of dataset SDK (#2147)

### What problem does this PR solve?

Complete implementation of dataset SDK.
#1102

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
tags/v0.11.0
LiuHua 1 vuosi sitten
vanhempi
commit
f87e7242cd
No account linked to committer's email address

+ 118
- 44
api/apps/sdk/dataset.py Näytä tiedosto

# #
from flask import request from flask import request
from api.db import StatusEnum
from api.db.db_models import APIToken
from api.db import StatusEnum, FileSource
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from api.utils.api_utils import get_json_result, token_required, get_data_error_result
@manager.route('/save', methods=['POST']) @manager.route('/save', methods=['POST'])
def save():
@token_required
def save(tenant_id):
req = request.json req = request.json
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
e, t = TenantService.get_by_id(tenant_id) e, t = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(retmsg="Tenant not found.")
if "id" not in req: if "id" not in req:
if "tenant_id" in req or "embd_id" in req:
return get_data_error_result(
retmsg="Tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name is not empty!")
req['id'] = get_uuid() req['id'] = get_uuid()
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
if req["name"] == "": if req["name"] == "":
return get_data_error_result( return get_data_error_result(
retmsg="Name is not empty")
if KnowledgebaseService.query(name=req["name"]):
retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result( return get_data_error_result(
retmsg="Duplicated knowledgebase name")
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = tenant_id req["tenant_id"] = tenant_id
req['created_by'] = tenant_id req['created_by'] = tenant_id
req['embd_id'] = t.embd_id req['embd_id'] = t.embd_id
if not KnowledgebaseService.save(**req): if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Data saving error")
req.pop('created_by')
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
for old_key,new_key in keys_to_rename.items():
if old_key in req:
req[new_key]=req.pop(old_key)
return get_data_error_result(retmsg="Create dataset error.(Database error)")
return get_json_result(data=req) return get_json_result(data=req)
else: else:
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change tenant_id or embedding_model")
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_data_error_result(
retmsg="Can't change tenant_id.")
e, kb = KnowledgebaseService.get_by_id(req["id"])
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if "embd_id" in req:
if req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change embedding_model.")
if not KnowledgebaseService.query( if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]): created_by=tenant_id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR) retcode=RetCode.OPERATING_ERROR)
if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count or chunk_count ")
e, kb = KnowledgebaseService.get_by_id(req["id"])
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parser method is not changable. ")
if "chunk_num" in req:
if req["chunk_num"] != kb.chunk_num:
return get_data_error_result(
retmsg="Can't change chunk_count.")
if "doc_num" in req:
if req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name.")
if "parser_id" in req:
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parse method is not changable.")
if "name" in req:
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")
del req["id"] del req["id"]
req['created_by'] = tenant_id
if not KnowledgebaseService.update_by_id(kb.id, req): if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Data update error ")
return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
    """Delete a dataset created by the calling tenant, along with its documents.

    The dataset id is read from the ``id`` query-string argument. Every
    document of the dataset is removed first, then the file record and the
    file/document mapping tied to it, and finally the dataset row itself.

    Args:
        tenant_id: Tenant resolved from the API token by ``token_required``.

    Returns:
        A JSON result with ``data=True`` on success, or an error result when
        the tenant did not create the dataset or a removal step fails.
    """
    req = request.args
    kbs = KnowledgebaseService.query(
        created_by=tenant_id, id=req["id"])
    if not kbs:
        return get_json_result(
            data=False, retmsg='You do not own the dataset',
            retcode=RetCode.OPERATING_ERROR)

    for doc in DocumentService.query(kb_id=req["id"]):
        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
            return get_data_error_result(
                retmsg="Remove document error.(Database error)")
        f2d = File2DocumentService.get_by_document_id(doc.id)
        # A document may have no file mapping at all; indexing f2d[0] without
        # this guard would raise IndexError and abort the deletion half-way.
        if f2d:
            FileService.filter_delete(
                [File.source_type == FileSource.KNOWLEDGEBASE,
                 File.id == f2d[0].file_id])
        File2DocumentService.delete_by_document_id(doc.id)

    if not KnowledgebaseService.delete_by_id(req["id"]):
        return get_data_error_result(
            retmsg="Delete dataset error.(Database error)")
    return get_json_result(data=True)
@manager.route('/list', methods=['GET'])
@token_required
def list_datasets(tenant_id):
    """List the datasets visible to the calling tenant, one page at a time.

    Query-string arguments:
        page: 1-based page number (default 1).
        page_size: number of items per page (default 1024).
        orderby: field to sort on (default ``create_time``).
        desc: sort descending unless given as an empty string or one of
            ``0``/``false``/``no``/``off`` (case-insensitive); default true.

    Args:
        tenant_id: Tenant resolved from the API token by ``token_required``.

    Returns:
        A JSON result whose ``data`` is whatever
        ``KnowledgebaseService.get_by_tenant_ids`` returns for the page.
    """
    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 1024))
    orderby = request.args.get("orderby", "create_time")

    # Query-string values arrive as text and bool("False") is True, so the
    # flag has to be parsed explicitly for clients to be able to turn it off.
    desc_arg = request.args.get("desc", True)
    if isinstance(desc_arg, str):
        desc = desc_arg.strip().lower() not in ("", "0", "false", "no", "off")
    else:
        desc = bool(desc_arg)

    tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
    kbs = KnowledgebaseService.get_by_tenant_ids(
        [m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
    return get_json_result(data=kbs)
@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
    """Return one dataset of the calling tenant, looked up by id and/or name.

    Query-string arguments:
        id: dataset id. When given, the dataset must have been created by the
            tenant; if ``name`` is also given it must match the stored name.
        name: dataset name, used on its own when ``id`` is absent.

    Args:
        tenant_id: Tenant resolved from the API token by ``token_required``.

    Returns:
        A JSON result carrying the dataset as a dict, or an error result when
        neither argument is supplied or the dataset cannot be resolved for
        this tenant.
    """
    req = request.args

    if "id" in req:
        dataset_id = req["id"]  # named so the builtin ``id`` is not shadowed
        owned = KnowledgebaseService.query(created_by=tenant_id, id=dataset_id)
        if not owned:
            return get_json_result(
                data=False, retmsg='You do not own the dataset',
                retcode=RetCode.OPERATING_ERROR)
        if "name" in req and owned[0].name != req["name"]:
            return get_json_result(
                data=False, retmsg='You do not own the dataset',
                retcode=RetCode.OPERATING_ERROR)
        e, k = KnowledgebaseService.get_by_id(dataset_id)
        # The row could disappear between the ownership check and this fetch;
        # report that instead of failing on ``None.to_dict()``.
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        return get_json_result(data=k.to_dict())

    if "name" in req:
        e, k = KnowledgebaseService.get_by_name(kb_name=req["name"], tenant_id=tenant_id)
        if not e:
            return get_json_result(
                data=False, retmsg='You do not own the dataset',
                retcode=RetCode.OPERATING_ERROR)
        return get_json_result(data=k.to_dict())

    return get_data_error_result(
        retmsg="At least one of `id` or `name` must be provided.")

+ 34
- 13
api/utils/api_utils.py Näytä tiedosto

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import functools
import json import json
import random import random
import time import time
from base64 import b64encode
from functools import wraps from functools import wraps
from hmac import HMAC
from io import BytesIO from io import BytesIO
from urllib.parse import quote, urlencode
from uuid import uuid1

import requests
from flask import ( from flask import (
Response, jsonify, send_file, make_response, Response, jsonify, send_file, make_response,
request as flask_request, request as flask_request,
) )
from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES


from api.utils import json_dumps
from api.settings import RetCode
from api.db.db_models import APIToken
from api.settings import ( from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
) )
import requests
import functools
from api.settings import RetCode
from api.utils import CustomJSONEncoder from api.utils import CustomJSONEncoder
from uuid import uuid1
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode
from api.utils import json_dumps


requests.models.complexjson.dumps = functools.partial( requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder) json.dumps, cls=CustomJSONEncoder)


def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
data=None, job_id=None, meta=None): data=None, job_id=None, meta=None):
import re
result_dict = { result_dict = {
"retcode": retcode, "retcode": retcode,
"retmsg": retmsg, "retmsg": retmsg,
return get_json_result( return get_json_result(
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0: if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
retmsg="No chunk found, please upload file and parse it.")


return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))


return get_json_result( return get_json_result(
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs) return func(*_args, **_kwargs)

return decorated_function return decorated_function

return wrapper return wrapper








def construct_response(retcode=RetCode.SUCCESS, def construct_response(retcode=RetCode.SUCCESS,
retmsg='success', data=None, auth=None):
retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {} response_dict = {}
for key, value in result_dict.items(): for key, value in result_dict.items():
response.headers["Access-Control-Expose-Headers"] = "Authorization" response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response return response



def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
import re import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >=0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")


return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))


def token_required(func):
    """Decorator that authenticates a request by its API token.

    The token is taken from the second whitespace-separated part of the
    ``Authorization`` header (e.g. ``Bearer <token>``) and looked up with
    ``APIToken.query``. On success the wrapped view is called with an extra
    ``tenant_id`` keyword argument taken from the first matching token.

    Args:
        func: The view function to protect.

    Returns:
        The wrapped view. It responds with an ``AUTHENTICATION_ERROR`` JSON
        result when the header is missing, malformed, or the token is unknown.
    """

    @wraps(func)
    def decorated_function(*args, **kwargs):
        # A missing header yields None and a header without a second part
        # yields too few pieces; both used to crash instead of being rejected.
        header_parts = (flask_request.headers.get('Authorization') or '').split()
        if len(header_parts) < 2:
            return get_json_result(
                data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
            )
        token = header_parts[1]
        objs = APIToken.query(token=token)
        if not objs:
            return get_json_result(
                data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
            )
        kwargs['tenant_id'] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function

+ 8
- 4
sdk/python/ragflow/modules/base.py Näytä tiedosto

pr[name] = value pr[name] = value
return pr return pr
def post(self, path, param): def post(self, path, param):
res = self.rag.post(path,param)
res = self.rag.post(path, param)
return res return res
def get(self, path, params=''):
res = self.rag.get(path,params)
def get(self, path, params):
res = self.rag.get(path, params)
return res return res
def rm(self, path, params):
    """Forward a delete call to the owning client and return its response as-is."""
    return self.rag.delete(path, params)
def __str__(self):
    """Return the string form of this object's ``to_json()`` result."""
    as_json = self.to_json()
    return str(as_json)

+ 23
- 5
sdk/python/ragflow/modules/dataset.py Näytä tiedosto

self.permission = "me" self.permission = "me"
self.document_count = 0 self.document_count = 0
self.chunk_count = 0 self.chunk_count = 0
self.parser_method = "naive"
self.parse_method = "naive"
self.parser_config = None self.parser_config = None
for k in list(res_dict.keys()):
if k == "embd_id":
res_dict["embedding_model"] = res_dict[k]
if k == "parser_id":
res_dict['parse_method'] = res_dict[k]
if k == "doc_num":
res_dict["document_count"] = res_dict[k]
if k == "chunk_num":
res_dict["chunk_count"] = res_dict[k]
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def save(self):
def save(self) -> bool:
res = self.post('/dataset/save', res = self.post('/dataset/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id, {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
"description": self.description, "language": self.language, "embd_id": self.embedding_model, "description": self.description, "language": self.language, "embd_id": self.embedding_model,
"permission": self.permission, "permission": self.permission,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
"parser_config": self.parser_config.to_json() "parser_config": self.parser_config.to_json()
}) })
res = res.json() res = res.json()
if not res.get("retmsg"): return True
raise Exception(res["retmsg"])
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])
def delete(self) -> bool:
    """Delete this dataset on the server.

    Returns:
        True when the server reports ``retmsg == "success"``.

    Raises:
        Exception: carrying the server's ``retmsg`` for any other outcome.
    """
    payload = self.rm('/dataset/delete', {"id": self.id}).json()
    if payload.get("retmsg") != "success":
        raise Exception(payload["retmsg"])
    return True

+ 40
- 14
sdk/python/ragflow/ragflow.py Näytä tiedosto

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.


from typing import List

import requests import requests


from .modules.dataset import DataSet from .modules.dataset import DataSet
""" """
self.user_key = user_key self.user_key = user_key
self.api_url = f"{base_url}/api/{version}" self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}


def post(self, path, param): def post(self, path, param):
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
return res return res


def get(self, path, params=''):
res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
def get(self, path, params=None):
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
return res

def delete(self, path, params):
res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
return res return res


def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me",
document_count:int=0,chunk_count:int=0,parser_method:str="naive",
parser_config:DataSet.ParserConfig=None):
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
permission: str = "me",
document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive",
parser_config: DataSet.ParserConfig = None) -> DataSet:
if parser_config is None: if parser_config is None:
parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12})
parser_config=parser_config.to_json()
res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission,
"doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method,
"parser_config":parser_config
}
)
parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
"delimiter": "\n!?。;!?", "task_page_size": 12})
parser_config = parser_config.to_json()
res = self.post("/dataset/save",
{"name": name, "avatar": avatar, "description": description, "language": language,
"permission": permission,
"doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
"parser_config": parser_config
}
)
res = res.json() res = res.json()
if not res.get("retmsg"):
if res.get("retmsg") == "success":
return DataSet(self, res["data"]) return DataSet(self, res["data"])
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])


def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \
        List[DataSet]:
    """Fetch one page of datasets from the server and wrap each entry as a ``DataSet``.

    Raises:
        Exception: carrying the server's ``retmsg`` when it is not ``"success"``.
    """
    query = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}
    payload = self.get("/dataset/list", query).json()
    if payload.get("retmsg") != "success":
        raise Exception(payload["retmsg"])
    return [DataSet(self, entry) for entry in payload['data']]


def get_dataset(self, id: str = None, name: str = None) -> DataSet:
    """Fetch a single dataset by id and/or name and wrap it as a ``DataSet``.

    Raises:
        Exception: carrying the server's ``retmsg`` when it is not ``"success"``.
    """
    payload = self.get("/dataset/detail", {"id": id, "name": name}).json()
    if payload.get("retmsg") != "success":
        raise Exception(payload["retmsg"])
    return DataSet(self, payload['data'])

+ 36
- 5
sdk/python/test/t_dataset.py Näytä tiedosto

class TestDataset(TestSdk): class TestDataset(TestSdk):
def test_create_dataset_with_success(self): def test_create_dataset_with_success(self):
""" """
Test creating dataset with success
Test creating a dataset with success
""" """
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God") ds = rag.create_dataset("God")
def test_update_dataset_with_success(self): def test_update_dataset_with_success(self):
""" """
Test updating dataset with success.
Test updating a dataset with success.
""" """
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC") ds = rag.create_dataset("ABC")
if isinstance(ds, DataSet): if isinstance(ds, DataSet):
assert ds.name == "ABC", "Name does not match."
assert ds.name == "ABC", "Name does not match."
ds.name = 'DEF' ds.name = 'DEF'
res = ds.save() res = ds.save()
assert res is True, f"Failed to update dataset, error: {res}"
assert res is True, f"Failed to update dataset, error: {res}"
else:
assert False, f"Failed to create dataset, error: {ds}"
def test_delete_dataset_with_success(self):
"""
Test deleting a dataset with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("MA")
if isinstance(ds, DataSet):
assert ds.name == "MA", "Name does not match."
res = ds.delete()
assert res is True, f"Failed to delete dataset, error: {res}"
else: else:
assert False, f"Failed to create dataset, error: {ds}"
assert False, f"Failed to create dataset, error: {ds}"
def test_list_datasets_with_success(self):
    """
    Listing datasets returns a non-empty list made up only of DataSet objects.
    """
    client = RAGFlow(API_KEY, HOST_ADDRESS)
    datasets = client.list_datasets()
    assert len(datasets) > 0, "Do not exist any dataset"
    for item in datasets:
        assert isinstance(item, DataSet), "Existence type is not dataset."
def test_get_detail_dataset_with_success(self):
    """
    Looking a dataset up by name returns a DataSet carrying that name.
    """
    client = RAGFlow(API_KEY, HOST_ADDRESS)
    found = client.get_dataset(name="God")
    assert isinstance(found, DataSet), f"Failed to get dataset, error: {found}."
    assert found.name == "God", "Name does not match"

Loading…
Peruuta
Tallenna