浏览代码

add a lot of APIs (#23)

* clean rust version project

* clean rust version project

* build python version rag-flow

* add a lot of APIs
tags/v0.1.0
KevinHuSh 1年前
父节点
当前提交
3198faf2d2

+ 1
- 1
rag/llm/embedding_model.py 查看文件





class HuEmbedding(Base): class HuEmbedding(Base):
def __init__(self):
def __init__(self, key="", model_name=""):
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!



+ 6
- 3
rag/nlp/huchunk.py 查看文件

flds = self.Fields() flds = self.Fields()
if self.is_binary_file(fnm): if self.is_binary_file(fnm):
return flds return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
txt = ""
if isinstance(fnm, str):
with open(fnm, "r") as f:
txt = f.read()
else: txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = [] flds.table_chunks = []
return flds return flds



+ 1
- 1
rag/nlp/search.py 查看文件

import numpy as np import numpy as np




def index_name(uid): return f"docgpt_{uid}"
def index_name(uid): return f"ragflow_{uid}"




class Dealer: class Dealer:

+ 35
- 19
rag/svr/parse_user_docs.py 查看文件

# limitations under the License. # limitations under the License.
# #
import json import json
import logging
import os import os
import hashlib import hashlib
import copy import copy


from rag.llm import EmbeddingModel, CvModel from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt
from rag.utils import rmSpace, findMaxTm

from rag.nlp import huchunk, huqie, search from rag.nlp import huchunk, huqie, search
from io import BytesIO from io import BytesIO
import pandas as pd import pandas as pd
from web_server.db import LLMType from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory from web_server.utils.file_utils import get_project_base_directory


if len(docs) == 0: if len(docs) == 0:
return pd.DataFrame() return pd.DataFrame()
docs = pd.DataFrame(docs) docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19]
mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs return docs


cron_logger.error("set_progress:({}), {}".format(docid, str(e))) cron_logger.error("set_progress:({}), {}".format(docid, str(e)))




def build(row):
def build(row, cvmdl):
if row["size"] > DOC_MAXIMUM_SIZE: if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024))) (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return [] return []

res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0: if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
set_progress(row["id"], random.randint(0, 20) / set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True) 100., "Finished preparing! Start to slice file!", True)
try: try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
set_progress( set_progress(
row["id"], -1, f"Internal server error: %s" % row["id"], -1, f"Internal server error: %s" %
str(e).replace( str(e).replace(
"'", "")) "'", ""))

cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))

return [] return []


if not obj.text_chunks and not obj.table_chunks: if not obj.text_chunks and not obj.table_chunks:
"Finished slicing files. Start to embedding the content.") "Finished slicing files. Start to embedding the content.")


doc = { doc = {
"doc_id": row["did"],
"doc_id": row["id"],
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1], "docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]), "title_tks": huqie.qie(row["name"]),
docs.append(d) docs.append(d)
continue continue


if isinstance(img, Image):
img.save(output_buffer, format='JPEG')
else:
if isinstance(img, bytes):
output_buffer = BytesIO(img) output_buffer = BytesIO(img)
else:
img.save(output_buffer, format='JPEG')


MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])




def model_instance(tenant_id, llm_type): def model_instance(tenant_id, llm_type):
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:return
model_config = model_config[0]
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING: if llm_type == LLMType.EMBEDDING:
if model_config.llm_factory not in EmbeddingModel: return
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT: if llm_type == LLMType.IMAGE2TEXT:
if model_config.llm_factory not in CvModel: return
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])




def main(comm, mod): def main(comm, mod):
from rag.llm import HuEmbedding from rag.llm import HuEmbedding
model = HuEmbedding() model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm)
tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm) rows = collect(comm, mod, tm)
if len(rows) == 0: if len(rows) == 0:
return return
st_tm = timer() st_tm = timer()
cks = build(r, cv_mdl) cks = build(r, cv_mdl)
if not cks: if not cks:
tmf.write(str(r["updated_at"]) + "\n")
tmf.write(str(r["update_time"]) + "\n")
continue continue
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
cron_logger.error(str(es_r)) cron_logger.error(str(es_r))
else: else:
set_progress(r["id"], 1., "Done!") set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))

tmf.write(str(r["update_time"]) + "\n") tmf.write(str(r["update_time"]) + "\n")
tmf.close() tmf.close()




if __name__ == "__main__": if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)

from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank()) main(comm.Get_size(), comm.Get_rank())

+ 19
- 0
rag/utils/__init__.py 查看文件

print("WARNING: can't find " + fnm) print("WARNING: can't find " + fnm)
return m return m


def findMaxTm(fnm):
    """Return the largest integer recorded in the file *fnm*.

    The file is read as one value per line.  Lines that are empty, equal to
    ``nan``, or not parseable as an integer are skipped, so a single bad line
    no longer stops the scan and hides larger values that follow it.

    Args:
        fnm: Path of the file to scan.

    Returns:
        The maximum integer found, or 0 when the file cannot be read or
        holds no value greater than 0.
    """
    m = 0
    try:
        with open(fnm, "r") as f:
            for line in f:
                line = line.strip()
                if not line or line == 'nan':
                    continue
                try:
                    value = int(line)
                except ValueError:
                    # Malformed entry: ignore it and keep scanning.
                    continue
                if value > m:
                    m = value
    except Exception:
        # Best effort: an unreadable file is treated as "no value yet".
        print("WARNING: can't find " + fnm)
    return m

def num_tokens_from_string(string: str) -> int: def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string.""" """Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base') encoding = tiktoken.get_encoding('cl100k_base')

+ 1
- 0
rag/utils/es_conn.py 查看文件

except Exception as e: except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " + es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】:" + str(query.to_dict())) str(e) + "【Q】:" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue



+ 47
- 2
web_server/apps/document_app.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import base64
import pathlib import pathlib
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(retmsg="Document not found!") return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]): if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result( return get_data_error_result(
retmsg="Database error (Document removal)!") retmsg="Database error (Document removal)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
MINIO.rm(kb.id, doc.location)
MINIO.rm(doc.kb_id, doc.location)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return the stored content of a document, Base64-encoded.

    Reads ``doc_id`` from the query string, looks the document up and fetches
    its blob from MINIO using the document's ``kb_id`` and ``location``.

    Returns:
        A JSON result whose data is ``{"base64": <text>}``, a data-error
        result when the document does not exist, or a server-error response
        when any step raises.
    """
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        # NOTE(review): assumes MINIO.get returns the raw file content as
        # bytes -- confirm against the MINIO wrapper.
        blob = MINIO.get(doc.kb_id, doc.location)
        # The payload must be *encoded* to Base64 (and turned into str so it
        # is JSON-serializable); decoding raw file bytes was a bug.
        return get_json_result(
            data={"base64": base64.b64encode(blob).decode("utf-8")})
    except Exception as e:
        return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Switch a document to another parser and reset its processing state.

    When the requested parser equals the current one (case-insensitively)
    nothing is changed.  Otherwise the document's parser id is updated, its
    progress fields are reset, and its token/chunk/duration figures are
    subtracted through ``DocumentService.increment_chunk_num``.

    Returns:
        A JSON result with ``data=True`` on success, a data-error result when
        a lookup or update reports failure, or a server-error response when
        any step raises.
    """
    payload = request.json
    try:
        found, doc = DocumentService.get_by_id(payload["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")

        # Same parser already selected: nothing to do.
        if doc.parser_id.lower() == payload["parser_id"].lower():
            return get_json_result(data=True)

        reset_fields = {
            "parser_id": payload["parser_id"],
            "progress": 0,
            "progress_msg": "",
        }
        updated = DocumentService.update_by_id(doc.id, reset_fields)
        if not updated:
            return get_data_error_result(retmsg="Document not found!")

        # Negate the document's current figures so the counters go back down.
        adjusted = DocumentService.increment_chunk_num(
            doc.id,
            doc.kb_id,
            doc.token_num * -1,
            doc.chunk_num * -1,
            doc.process_duation * -1,
        )
        if not adjusted:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as err:
        return server_error_response(err)

+ 14
- 2
web_server/apps/kb_app.py 查看文件

@manager.route('/create', methods=['post']) @manager.route('/create', methods=['post'])
@login_required @login_required
@validate_request("name", "description", "permission", "embd_id", "parser_id")
@validate_request("name", "description", "permission", "parser_id")
def create(): def create():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
@manager.route('/update', methods=['post']) @manager.route('/update', methods=['post'])
@login_required @login_required
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
@validate_request("kb_id", "name", "description", "permission", "parser_id")
def update(): def update():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
return server_error_response(e) return server_error_response(e)
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of the knowledge base named by ``kb_id``.

    ``kb_id`` is taken from the query string and handed to
    ``KnowledgebaseService.get_detail``.

    Returns:
        A JSON result carrying the detail data, a data-error result when the
        service returns nothing, or a server-error response when it raises.
    """
    kb_id = request.args["kb_id"]
    try:
        kb = KnowledgebaseService.get_detail(kb_id)
        if kb:
            return get_json_result(data=kb)
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    except Exception as err:
        return server_error_response(err)
@manager.route('/list', methods=['GET']) @manager.route('/list', methods=['GET'])
@login_required @login_required
def list(): def list():

+ 95
- 0
web_server/apps/llm_app.py 查看文件

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """Return everything ``LLMFactoriesService.get_all()`` yields, as JSON.

    Returns:
        A JSON result wrapping ``to_json()`` of the service result, or a
        server-error response when the lookup or conversion raises.
    """
    try:
        return get_json_result(data=LLMFactoriesService.get_all().to_json())
    except Exception as err:
        return server_error_response(err)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (insert or replace) an API key for the current user's tenant.

    The request JSON must contain ``llm_factory`` and ``api_key``;
    ``model_type`` and ``llm_name`` are copied into the record when present.
    The current user's id is used as the tenant id.

    Returns:
        A JSON result with ``data=True`` on success, or a server-error
        response when building or writing the record raises (matching the
        error handling of the other handlers in this module).
    """
    req = request.json
    try:
        llm = {
            "tenant_id": current_user.id,
            "llm_factory": req["llm_factory"],
            "api_key": req["api_key"],
        }
        # TODO: Test api_key
        for n in ("model_type", "llm_name"):
            if n in req:
                llm[n] = req[n]

        # Replace any existing row that conflicts on the table's key.
        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """List the current user's tenant LLM records without their API keys.

    Returns:
        A JSON result holding one dict per record with the ``api_key`` entry
        removed, or a server-error response when anything raises.
    """
    try:
        records = [
            row.to_dict()
            for row in TenantLLMService.query(tenant_id=current_user.id)
        ]
        # Never send the secret back to the client.
        for record in records:
            record.pop("api_key")
        return get_json_result(data=records)
    except Exception as err:
        return server_error_response(err)
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return all valid LLMs grouped by factory id, flagged by availability.

    A model is marked ``available`` when the current user's tenant has a
    record with an API key for the model's factory and that factory either
    has no model names recorded or includes this model's name.

    Returns:
        A JSON result mapping each factory id to its list of model dicts, or
        a server-error response when anything raises.
    """
    try:
        keyed = [
            rec.to_dict()
            for rec in TenantLLMService.query(tenant_id=current_user.id)
            if rec.api_key
        ]
        # factory name -> model names explicitly configured for it
        names_by_factory = {}
        for rec in keyed:
            names = names_by_factory.setdefault(rec["llm_factory"], [])
            if rec["llm_name"]:
                names.append(rec["llm_name"])

        models = [
            m.to_dict()
            for m in LLMService.get_all()
            if m.status == StatusEnum.VALID.value
        ]

        grouped = {}
        for model in models:
            configured = names_by_factory.get(model["fid"])
            # An empty name list means the key covers the whole factory.
            model["available"] = configured is not None and (
                not configured or model["llm_name"] in configured)
            grouped.setdefault(model["fid"], []).append(model)

        return get_json_result(data=grouped)
    except Exception as err:
        return server_error_response(err)

+ 33
- 5
web_server/apps/user_app.py 查看文件

from flask import request, session, redirect, url_for from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole
from web_server.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger from web_server.settings import stat_logger
avatar = download_img(userinfo["avatar_url"]) avatar = download_img(userinfo["avatar_url"])
except Exception as e: except Exception as e:
stat_logger.exception(e) stat_logger.exception(e)
user_id = get_uuid()
try: try:
users = user_register({
users = user_register(user_id, {
"access_token": session["access_token"], "access_token": session["access_token"],
"email": userinfo["email"], "email": userinfo["email"],
"avatar": avatar, "avatar": avatar,
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return server_error_response(e) return server_error_response(e)
elif not request.json: elif not request.json:
return get_json_result(data=current_user.to_dict()) return get_json_result(data=current_user.to_dict())
def user_register(user):
def rollback_user_registration(user_id):
    """Best-effort removal of the records created while registering a user.

    Attempts, independently of each other, to delete the tenant whose id is
    *user_id*, the first user-tenant link found for that tenant, and every
    ``TenantLLM`` row of that tenant.  A failing step is logged and does not
    prevent the remaining steps from running; nothing is raised to the caller.

    Args:
        user_id: Id that was used for both the user and its tenant.
    """
    try:
        TenantService.delete_by_id(user_id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        # Was misspelled ``excute()``: the AttributeError was swallowed and
        # the rows were never deleted.
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        stat_logger.exception(e)
def user_register(user_id, user):
user_id = get_uuid() user_id = get_uuid()
user["id"] = user_id user["id"] = user_id
tenant = { tenant = {
"invited_by": user_id, "invited_by": user_id,
"role": UserTenantRole.OWNER "role": UserTenantRole.OWNER
} }
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
if not UserService.save(**user):return if not UserService.save(**user):return
TenantService.save(**tenant) TenantService.save(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
return UserService.query(email=user["email"]) return UserService.query(email=user["email"])
"last_login_time": get_format_time(), "last_login_time": get_format_time(),
"is_superuser": False, "is_superuser": False,
} }
user_id = get_uuid()
try: try:
users = user_register(user_dict)
users = user_register(user_id, user_dict)
if not users: raise Exception('Register user failure.') if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!') if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0] user = users[0]
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@login_required @login_required
def tenant_info(): def tenant_info():
try: try:
tenants = TenantService.get_by_user_id(current_user.id)
tenants = TenantService.get_by_user_id(current_user.id)[0]
return get_json_result(data=tenants) return get_json_result(data=tenants)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

+ 7
- 4
web_server/db/db_models.py 查看文件

class LLM(DataBaseModel): class LLM(DataBaseModel):
# defautlt LLMs for every users # defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id") fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
class TenantLLM(DataBaseModel): class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False) tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY") api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base") api_base = CharField(max_length=255, null=True, help_text="API Base")
class Meta: class Meta:
db_table = "tenant_llm" db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory')
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel): class Knowledgebase(DataBaseModel):
permission = CharField(max_length=16, null=False, help_text="me|team") permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False) created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

+ 27
- 13
web_server/db/services/document_service.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from peewee import Expression
from web_server.db import TenantPermission, FileType from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase
from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService from web_server.db.services.kb_service import KnowledgebaseService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum from web_server.db.db_utils import StatusEnum
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where(
cls.model.status == StatusEnum.VALID.value,
cls.model.type != FileType.VIRTUAL,
cls.model.progress == 0,
cls.model.update_time >= tm,
cls.model.create_time %
comm == mod).order_by(
cls.model.update_time.asc()).paginate(
1,
items_per_page)
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts()) return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add deltas to a document's counters and to its knowledge base totals.

    Args:
        doc_id: Id of the document row to adjust.
        kb_id: Id of the knowledge base row to adjust.
        token_num: Amount added to ``token_num`` (may be negative).
        chunk_num: Amount added to ``chunk_num`` (may be negative).
        duation: Amount added to the document's ``process_duation``.

    Returns:
        The value returned by executing the knowledge base update.

    Raises:
        LookupError: If the document update affected no rows.
    """
    doc_model = cls.model
    doc_update = doc_model.update(
        token_num=doc_model.token_num + token_num,
        chunk_num=doc_model.chunk_num + chunk_num,
        process_duation=doc_model.process_duation + duation,
    ).where(doc_model.id == doc_id)
    if doc_update.execute() == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_update = Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num,
    ).where(Knowledgebase.id == kb_id)
    return kb_update.execute()

+ 33
- 6
web_server/db/services/kb_service.py 查看文件

from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time from web_server.utils import get_uuid, get_format_time
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc):
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where( kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status==StatusEnum.VALID.value)
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
) )
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page) kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts()) return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
    """Fetch one valid knowledge base joined with its valid tenant.

    Args:
        kb_id: Id of the knowledge base to look up.

    Returns:
        The first matching row converted with ``to_dict()``, with
        ``embd_id`` set from the joined tenant, or ``None`` when the query
        yields nothing.
    """
    kb_model = cls.model
    columns = (
        kb_model.id,
        Tenant.embd_id,
        kb_model.avatar,
        kb_model.name,
        kb_model.description,
        kb_model.permission,
        kb_model.doc_num,
        kb_model.token_num,
        kb_model.chunk_num,
        kb_model.parser_id,
    )
    # Only join tenants that are themselves valid.
    tenant_match = (
        (Tenant.id == kb_model.tenant_id)
        & (Tenant.status == StatusEnum.VALID.value)
    )
    kbs = kb_model.select(*columns).join(Tenant, on=tenant_match).where(
        (kb_model.id == kb_id),
        (kb_model.status == StatusEnum.VALID.value),
    )
    if not kbs:
        return

    first = kbs[0]
    detail = first.to_dict()
    # The embedding model id lives on the tenant, not on the knowledge base.
    detail["embd_id"] = first.tenant.embd_id
    return detail

+ 18
- 0
web_server/db/services/llm_service.py 查看文件

class TenantLLMService(CommonService): class TenantLLMService(CommonService):
model = TenantLLM model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
    """Find the tenant's LLM record to use for a given model type.

    First looks for a tenant record of that model type; if the first hit
    names a concrete model it is returned.  Otherwise the tenant records of
    that type are joined with the valid entries of ``LLM`` on the factory
    id, and the first joined row is returned.

    Args:
        tenant_id: Id of the tenant whose configuration is searched.
        model_type: Model type to match against ``model_type``.

    Returns:
        A single record, or ``None`` when nothing matches.
    """
    objs = cls.query(tenant_id=tenant_id, model_type=model_type)
    if objs and objs[0].llm_name:
        return objs[0]

    fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
    objs = cls.model.select(*fields).join(
        LLM, on=(LLM.fid == cls.model.llm_factory)
    ).where(
        (cls.model.tenant_id == tenant_id),
        (cls.model.model_type == model_type),
        # Compare with the stored value, as the other status filters do,
        # not with the enum member itself.
        (LLM.status == StatusEnum.VALID.value),
    )
    if not objs:
        return
    return objs[0]

+ 1
- 1
web_server/db/services/user_service.py 查看文件

@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_user_id(cls, user_id): def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\ return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts()) .where(cls.model.status == StatusEnum.VALID.value).dicts())

+ 1
- 1
web_server/utils/file_utils.py 查看文件

if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
return FileType.PDF.value return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):

正在加载...
取消
保存