* clean rust version project * clean rust version project * build python version rag-flow * add a lot of api tags/v0.1.0
| @@ -35,7 +35,7 @@ class Base(ABC): | |||
| class HuEmbedding(Base): | |||
| def __init__(self): | |||
| def __init__(self, key="", model_name=""): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -411,9 +411,12 @@ class TextChunker(HuChunker): | |||
| flds = self.Fields() | |||
| if self.is_binary_file(fnm): | |||
| return flds | |||
| with open(fnm, "r") as f: | |||
| txt = f.read() | |||
| flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] | |||
| txt = "" | |||
| if isinstance(fnm, str): | |||
| with open(fnm, "r") as f: | |||
| txt = f.read() | |||
| else: txt = fnm.decode("utf-8") | |||
| flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] | |||
| flds.table_chunks = [] | |||
| return flds | |||
| @@ -8,7 +8,7 @@ from rag.nlp import huqie, query | |||
| import numpy as np | |||
| def index_name(uid): return f"docgpt_{uid}" | |||
| def index_name(uid): return f"ragflow_{uid}" | |||
| class Dealer: | |||
| @@ -14,6 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| import logging | |||
| import os | |||
| import hashlib | |||
| import copy | |||
| @@ -24,9 +25,10 @@ from timeit import default_timer as timer | |||
| from rag.llm import EmbeddingModel, CvModel | |||
| from rag.settings import cron_logger, DOC_MAXIMUM_SIZE | |||
| from rag.utils import ELASTICSEARCH, num_tokens_from_string | |||
| from rag.utils import ELASTICSEARCH | |||
| from rag.utils import MINIO | |||
| from rag.utils import rmSpace, findMaxDt | |||
| from rag.utils import rmSpace, findMaxTm | |||
| from rag.nlp import huchunk, huqie, search | |||
| from io import BytesIO | |||
| import pandas as pd | |||
| @@ -47,6 +49,7 @@ from rag.nlp.huchunk import ( | |||
| from web_server.db import LLMType | |||
| from web_server.db.services.document_service import DocumentService | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.settings import database_logger | |||
| from web_server.utils import get_format_time | |||
| from web_server.utils.file_utils import get_project_base_directory | |||
| @@ -83,7 +86,7 @@ def collect(comm, mod, tm): | |||
| if len(docs) == 0: | |||
| return pd.DataFrame() | |||
| docs = pd.DataFrame(docs) | |||
| mtm = str(docs["update_time"].max())[:19] | |||
| mtm = docs["update_time"].max() | |||
| cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) | |||
| return docs | |||
| @@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False): | |||
| cron_logger.error("set_progress:({}), {}".format(docid, str(e))) | |||
| def build(row): | |||
| def build(row, cvmdl): | |||
| if row["size"] > DOC_MAXIMUM_SIZE: | |||
| set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % | |||
| (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) | |||
| return [] | |||
| res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) | |||
| if ELASTICSEARCH.getTotal(res) > 0: | |||
| ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), | |||
| @@ -120,7 +124,8 @@ def build(row): | |||
| set_progress(row["id"], random.randint(0, 20) / | |||
| 100., "Finished preparing! Start to slice file!", True) | |||
| try: | |||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) | |||
| cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | |||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) | |||
| except Exception as e: | |||
| if re.search("(No such file|not found)", str(e)): | |||
| set_progress( | |||
| @@ -131,6 +136,9 @@ def build(row): | |||
| row["id"], -1, f"Internal server error: %s" % | |||
| str(e).replace( | |||
| "'", "")) | |||
| cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e))) | |||
| return [] | |||
| if not obj.text_chunks and not obj.table_chunks: | |||
| @@ -144,7 +152,7 @@ def build(row): | |||
| "Finished slicing files. Start to embedding the content.") | |||
| doc = { | |||
| "doc_id": row["did"], | |||
| "doc_id": row["id"], | |||
| "kb_id": [str(row["kb_id"])], | |||
| "docnm_kwd": os.path.split(row["location"])[-1], | |||
| "title_tks": huqie.qie(row["name"]), | |||
| @@ -164,10 +172,10 @@ def build(row): | |||
| docs.append(d) | |||
| continue | |||
| if isinstance(img, Image): | |||
| img.save(output_buffer, format='JPEG') | |||
| else: | |||
| if isinstance(img, bytes): | |||
| output_buffer = BytesIO(img) | |||
| else: | |||
| img.save(output_buffer, format='JPEG') | |||
| MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | |||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | |||
| @@ -215,15 +223,16 @@ def embedding(docs, mdl): | |||
| def model_instance(tenant_id, llm_type): | |||
| model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING) | |||
| if not model_config:return | |||
| model_config = model_config[0] | |||
| model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING) | |||
| if not model_config: | |||
| model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""} | |||
| else: model_config = model_config[0].to_dict() | |||
| if llm_type == LLMType.EMBEDDING: | |||
| if model_config.llm_factory not in EmbeddingModel: return | |||
| return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) | |||
| if model_config["llm_factory"] not in EmbeddingModel: return | |||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) | |||
| if llm_type == LLMType.IMAGE2TEXT: | |||
| if model_config.llm_factory not in CvModel: return | |||
| return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) | |||
| if model_config["llm_factory"] not in CvModel: return | |||
| return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"]) | |||
| def main(comm, mod): | |||
| @@ -231,7 +240,7 @@ def main(comm, mod): | |||
| from rag.llm import HuEmbedding | |||
| model = HuEmbedding() | |||
| tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") | |||
| tm = findMaxDt(tm_fnm) | |||
| tm = findMaxTm(tm_fnm) | |||
| rows = collect(comm, mod, tm) | |||
| if len(rows) == 0: | |||
| return | |||
| @@ -247,7 +256,7 @@ def main(comm, mod): | |||
| st_tm = timer() | |||
| cks = build(r, cv_mdl) | |||
| if not cks: | |||
| tmf.write(str(r["updated_at"]) + "\n") | |||
| tmf.write(str(r["update_time"]) + "\n") | |||
| continue | |||
| # TODO: exception handler | |||
| ## set_progress(r["did"], -1, "ERROR: ") | |||
| @@ -268,12 +277,19 @@ def main(comm, mod): | |||
| cron_logger.error(str(es_r)) | |||
| else: | |||
| set_progress(r["id"], 1., "Done!") | |||
| DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) | |||
| DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm) | |||
| cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) | |||
| tmf.write(str(r["update_time"]) + "\n") | |||
| tmf.close() | |||
| if __name__ == "__main__": | |||
| peewee_logger = logging.getLogger('peewee') | |||
| peewee_logger.propagate = False | |||
| peewee_logger.addHandler(database_logger.handlers[0]) | |||
| peewee_logger.setLevel(database_logger.level) | |||
| from mpi4py import MPI | |||
| comm = MPI.COMM_WORLD | |||
| main(comm.Get_size(), comm.Get_rank()) | |||
| @@ -40,6 +40,25 @@ def findMaxDt(fnm): | |||
| print("WARNING: can't find " + fnm) | |||
| return m | |||
def findMaxTm(fnm):
    """Return the largest integer timestamp recorded in a checkpoint file.

    The file is scanned line by line and every line is expected to hold one
    integer. Blank lines, the literal ``nan`` and lines that do not parse as
    an integer are skipped, so one corrupt entry does not hide the (possibly
    larger) values that follow it.

    Args:
        fnm: Path of the file to read.

    Returns:
        The maximum integer found, or 0 when the file is missing, unreadable
        or holds no value greater than 0.
    """
    m = 0
    try:
        with open(fnm, "r") as f:
            for line in f:
                line = line.strip()
                if not line or line == "nan":
                    continue
                try:
                    value = int(line)
                except ValueError:
                    # Malformed entry: ignore it and keep scanning.
                    continue
                if value > m:
                    m = value
    except (OSError, UnicodeError):
        # Best effort: a missing or unreadable file means "no checkpoint yet".
        print("WARNING: can't find " + fnm)
    return m
| def num_tokens_from_string(string: str) -> int: | |||
| """Returns the number of tokens in a text string.""" | |||
| encoding = tiktoken.get_encoding('cl100k_base') | |||
| @@ -294,6 +294,7 @@ class HuEs: | |||
| except Exception as e: | |||
| es_logger.error("ES updateByQuery deleteByQuery: " + | |||
| str(e) + "【Q】:" + str(query.to_dict())) | |||
| if str(e).find("NotFoundError") > 0: return True | |||
| if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: | |||
| continue | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import base64 | |||
| import pathlib | |||
| from elasticsearch_dsl import Q | |||
| @@ -195,11 +196,15 @@ def rm(): | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Document not found!") | |||
| if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): | |||
| return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) | |||
| DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) | |||
| if not DocumentService.delete_by_id(req["doc_id"]): | |||
| return get_data_error_result( | |||
| retmsg="Database error (Document removal)!") | |||
| e, kb = KnowledgebaseService.get_by_id(doc.kb_id) | |||
| MINIO.rm(kb.id, doc.location) | |||
| MINIO.rm(doc.kb_id, doc.location) | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -233,3 +238,43 @@ def rename(): | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return the stored content of a document as a base64 string.

    Reads ``doc_id`` from the query string, looks the document up through
    ``DocumentService`` and fetches its blob from ``MINIO`` using the
    document's ``kb_id`` and ``location``.

    Returns:
        A JSON result whose data is ``{"base64": <text>}``, a data-error
        result when the document does not exist, or a server-error response
        when anything raises.
    """
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        blob = MINIO.get(doc.kb_id, doc.location)
        # The stored blob is presumably raw file bytes, so it has to be
        # encoded (not decoded) and turned into text to be JSON-serialisable.
        encoded = base64.b64encode(blob).decode("utf-8")
        return get_json_result(data={"base64": encoded})
    except Exception as e:
        return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Assign another parser to a document and reset its progress.

    Expects a JSON body with ``doc_id`` and ``parser_id``. When the requested
    parser equals the current one (case-insensitively) nothing is changed.
    Otherwise the parser is stored, ``progress``/``progress_msg`` are reset
    and the document's negated token, chunk and duration counters are passed
    to ``DocumentService.increment_chunk_num``.

    Returns:
        A JSON result with ``True`` on success, a data-error result when a
        lookup or update reports failure, or a server-error response when
        anything raises.
    """
    payload = request.json
    try:
        found, doc = DocumentService.get_by_id(payload["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")

        if doc.parser_id.lower() == payload["parser_id"].lower():
            # Same parser already assigned: nothing to reset.
            return get_json_result(data=True)

        reset_fields = {"parser_id": payload["parser_id"], "progress": 0, "progress_msg": ""}
        if not DocumentService.update_by_id(doc.id, reset_fields):
            return get_data_error_result(retmsg="Document not found!")

        adjusted = DocumentService.increment_chunk_num(
            doc.id,
            doc.kb_id,
            doc.token_num * -1,
            doc.chunk_num * -1,
            doc.process_duation * -1)
        if not adjusted:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as err:
        return server_error_response(err)
| @@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result | |||
| @manager.route('/create', methods=['post']) | |||
| @login_required | |||
| @validate_request("name", "description", "permission", "embd_id", "parser_id") | |||
| @validate_request("name", "description", "permission", "parser_id") | |||
| def create(): | |||
| req = request.json | |||
| req["name"] = req["name"].strip() | |||
| @@ -46,7 +46,7 @@ def create(): | |||
| @manager.route('/update', methods=['post']) | |||
| @login_required | |||
| @validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") | |||
| @validate_request("kb_id", "name", "description", "permission", "parser_id") | |||
| def update(): | |||
| req = request.json | |||
| req["name"] = req["name"].strip() | |||
| @@ -72,6 +72,18 @@ def update(): | |||
| return server_error_response(e) | |||
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of the knowledge base named by ``kb_id``.

    ``kb_id`` is read from the query string and handed to
    ``KnowledgebaseService.get_detail``.

    Returns:
        A JSON result carrying the record, a data-error result when the
        service returns nothing, or a server-error response when anything
        raises.
    """
    kb_id = request.args["kb_id"]
    try:
        record = KnowledgebaseService.get_detail(kb_id)
        if record:
            return get_json_result(data=record)
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    except Exception as err:
        return server_error_response(err)
| @manager.route('/list', methods=['GET']) | |||
| @login_required | |||
| def list(): | |||
| @@ -0,0 +1,95 @@ | |||
| # | |||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from web_server.db.services import duplicate_name | |||
| from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||
| from web_server.db.services.user_service import TenantService, UserTenantService | |||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db import StatusEnum, UserTenantRole | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.db.db_models import Knowledgebase, TenantLLM | |||
| from web_server.settings import stat_logger, RetCode | |||
| from web_server.utils.api_utils import get_json_result | |||
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """Return every LLM factory known to ``LLMFactoriesService``.

    Returns:
        A JSON result carrying ``get_all().to_json()``, or a server-error
        response when anything raises.
    """
    try:
        all_factories = LLMFactoriesService.get_all()
        payload = all_factories.to_json()
        return get_json_result(data=payload)
    except Exception as err:
        return server_error_response(err)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (insert or replace) an API key for the current user.

    Expects a JSON body with ``llm_factory`` and ``api_key``; the optional
    ``model_type`` and ``llm_name`` are copied when present. The row is keyed
    with ``current_user.id`` as ``tenant_id``.

    Returns:
        A JSON result with ``True`` on success, or a server-error response
        when anything raises (same convention as the other handlers here).
    """
    try:
        req = request.json
        llm = {
            "tenant_id": current_user.id,
            "llm_factory": req["llm_factory"],
            "api_key": req["api_key"]
        }
        # TODO: Test api_key
        for n in ["model_type", "llm_name"]:
            if n in req:
                llm[n] = req[n]

        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """List the current user's LLM settings with the API key removed.

    Returns:
        A JSON result carrying one dict per ``TenantLLMService`` row of the
        current user, each without its ``api_key`` entry, or a server-error
        response when anything raises.
    """
    try:
        rows = TenantLLMService.query(tenant_id=current_user.id)
        records = [row.to_dict() for row in rows]
        for record in records:
            # Never send the secret back to the client.
            del record["api_key"]
        return get_json_result(data=records)
    except Exception as err:
        return server_error_response(err)
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return all valid LLMs grouped by factory, flagged with availability.

    A model is marked ``available`` when the current user has a row with an
    API key for its factory and that factory either has no model names
    recorded or names this model.

    Returns:
        A JSON result mapping factory id to a list of model dicts, or a
        server-error response when anything raises.
    """
    try:
        # factory name -> model names the user has configured with an API key
        configured = {}
        for row in TenantLLMService.query(tenant_id=current_user.id):
            if not row.api_key:
                continue
            rec = row.to_dict()
            names = configured.setdefault(rec["llm_factory"], [])
            if rec["llm_name"]:
                names.append(rec["llm_name"])

        valid_models = [
            mdl.to_dict()
            for mdl in LLMService.get_all()
            if mdl.status == StatusEnum.VALID.value
        ]

        grouped = {}
        for rec in valid_models:
            fid = rec["fid"]
            rec["available"] = fid in configured and (
                not configured[fid] or rec["llm_name"] in configured[fid])
            grouped.setdefault(fid, []).append(rec)

        return get_json_result(data=grouped)
    except Exception as err:
        return server_error_response(err)
| @@ -16,9 +16,12 @@ | |||
| from flask import request, session, redirect, url_for | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from flask_login import login_required, current_user, login_user, logout_user | |||
| from web_server.db.db_models import TenantLLM | |||
| from web_server.db.services.llm_service import TenantLLMService | |||
| from web_server.utils.api_utils import server_error_response, validate_request | |||
| from web_server.utils import get_uuid, get_format_time, decrypt, download_img | |||
| from web_server.db import UserTenantRole | |||
| from web_server.db import UserTenantRole, LLMType | |||
| from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS | |||
| from web_server.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from web_server.settings import stat_logger | |||
| @@ -47,8 +50,9 @@ def login(): | |||
| avatar = download_img(userinfo["avatar_url"]) | |||
| except Exception as e: | |||
| stat_logger.exception(e) | |||
| user_id = get_uuid() | |||
| try: | |||
| users = user_register({ | |||
| users = user_register(user_id, { | |||
| "access_token": session["access_token"], | |||
| "email": userinfo["email"], | |||
| "avatar": avatar, | |||
| @@ -63,6 +67,7 @@ def login(): | |||
| login_user(user) | |||
| return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") | |||
| except Exception as e: | |||
| rollback_user_registration(user_id) | |||
| stat_logger.exception(e) | |||
| return server_error_response(e) | |||
| elif not request.json: | |||
| @@ -162,7 +167,25 @@ def user_info(): | |||
| return get_json_result(data=current_user.to_dict()) | |||
| def user_register(user): | |||
def rollback_user_registration(user_id):
    """Best-effort removal of the rows created for a failed registration.

    Deletes, independently of one another, the tenant with id ``user_id``,
    the first user-tenant link found for that tenant, and every ``TenantLLM``
    row of that tenant. A failure in one step is logged and does not stop
    the remaining steps; nothing is raised to the caller.

    Args:
        user_id: Id that was used for the user and tenant being registered.
    """
    try:
        TenantService.delete_by_id(user_id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        # `execute()` was misspelled, so this delete previously never ran.
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        stat_logger.exception(e)
| def user_register(user_id, user): | |||
| user_id = get_uuid() | |||
| user["id"] = user_id | |||
| tenant = { | |||
| @@ -180,10 +203,12 @@ def user_register(user): | |||
| "invited_by": user_id, | |||
| "role": UserTenantRole.OWNER | |||
| } | |||
| tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"} | |||
| if not UserService.save(**user):return | |||
| TenantService.save(**tenant) | |||
| UserTenantService.save(**usr_tenant) | |||
| TenantLLMService.save(**tenant_llm) | |||
| return UserService.query(email=user["email"]) | |||
| @@ -203,14 +228,17 @@ def user_add(): | |||
| "last_login_time": get_format_time(), | |||
| "is_superuser": False, | |||
| } | |||
| user_id = get_uuid() | |||
| try: | |||
| users = user_register(user_dict) | |||
| users = user_register(user_id, user_dict) | |||
| if not users: raise Exception('Register user failure.') | |||
| if len(users) > 1: raise Exception('Same E-mail exist!') | |||
| user = users[0] | |||
| login_user(user) | |||
| return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") | |||
| except Exception as e: | |||
| rollback_user_registration(user_id) | |||
| stat_logger.exception(e) | |||
| return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) | |||
| @@ -220,7 +248,7 @@ def user_add(): | |||
| @login_required | |||
| def tenant_info(): | |||
| try: | |||
| tenants = TenantService.get_by_user_id(current_user.id) | |||
| tenants = TenantService.get_by_user_id(current_user.id)[0] | |||
| return get_json_result(data=tenants) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel): | |||
| class LLM(DataBaseModel): | |||
| # default LLMs for every user | |||
| llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) | |||
| model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") | |||
| fid = CharField(max_length=128, null=False, help_text="LLM factory id") | |||
| tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| @@ -442,8 +443,8 @@ class LLM(DataBaseModel): | |||
| class TenantLLM(DataBaseModel): | |||
| tenant_id = CharField(max_length=32, null=False) | |||
| llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") | |||
| model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") | |||
| llm_name = CharField(max_length=128, null=False, help_text="LLM name") | |||
| model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") | |||
| llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") | |||
| api_key = CharField(max_length=255, null=True, help_text="API KEY") | |||
| api_base = CharField(max_length=255, null=True, help_text="API Base") | |||
| @@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel): | |||
| class Meta: | |||
| db_table = "tenant_llm" | |||
| primary_key = CompositeKey('tenant_id', 'llm_factory') | |||
| primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name') | |||
| class Knowledgebase(DataBaseModel): | |||
| @@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel): | |||
| permission = CharField(max_length=16, null=False, help_text="me|team") | |||
| created_by = CharField(max_length=32, null=False) | |||
| doc_num = IntegerField(default=0) | |||
| embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") | |||
| token_num = IntegerField(default=0) | |||
| chunk_num = IntegerField(default=0) | |||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID") | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| @@ -13,12 +13,13 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from peewee import Expression | |||
| from web_server.db import TenantPermission, FileType | |||
| from web_server.db.db_models import DB, Knowledgebase | |||
| from web_server.db.db_models import DB, Knowledgebase, Tenant | |||
| from web_server.db.db_models import Document | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.db.services.kb_service import KnowledgebaseService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| from web_server.db.db_utils import StatusEnum | |||
| @@ -61,15 +62,28 @@ class DocumentService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): | |||
| fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] | |||
| docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( | |||
| cls.model.status == StatusEnum.VALID.value, | |||
| cls.model.type != FileType.VIRTUAL, | |||
| cls.model.progress == 0, | |||
| cls.model.update_time >= tm, | |||
| cls.model.create_time % | |||
| comm == mod).order_by( | |||
| cls.model.update_time.asc()).paginate( | |||
| 1, | |||
| items_per_page) | |||
| fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time] | |||
| docs = cls.model.select(*fields) \ | |||
| .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ | |||
| .where( | |||
| cls.model.status == StatusEnum.VALID.value, | |||
| ~(cls.model.type == FileType.VIRTUAL.value), | |||
| cls.model.progress == 0, | |||
| cls.model.update_time >= tm, | |||
| (Expression(cls.model.create_time, "%%", comm) == mod))\ | |||
| .order_by(cls.model.update_time.asc())\ | |||
| .paginate(1, items_per_page) | |||
| return list(docs.dicts()) | |||
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add deltas to a document's counters and to its knowledge base.

    Args:
        doc_id: Id of the document row to update.
        kb_id: Id of the knowledge base row to update.
        token_num: Amount added to ``token_num`` on both rows.
        chunk_num: Amount added to ``chunk_num`` on both rows.
        duation: Amount added to the document's ``process_duation``.

    Returns:
        The value returned by executing the knowledge base update.

    Raises:
        LookupError: When the document update affects no row.
    """
    doc_model = cls.model
    doc_update = doc_model.update(
        token_num=doc_model.token_num + token_num,
        chunk_num=doc_model.chunk_num + chunk_num,
        process_duation=doc_model.process_duation + duation,
    ).where(doc_model.id == doc_id)
    if doc_update.execute() == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_update = Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num,
    ).where(Knowledgebase.id == kb_id)
    return kb_update.execute()
| @@ -17,7 +17,7 @@ import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from web_server.db import TenantPermission | |||
| from web_server.db.db_models import DB, UserTenant | |||
| from web_server.db.db_models import DB, UserTenant, Tenant | |||
| from web_server.db.db_models import Knowledgebase | |||
| from web_server.db.services.common_service import CommonService | |||
| from web_server.utils import get_uuid, get_format_time | |||
| @@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): | |||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, | |||
| page_number, items_per_page, orderby, desc): | |||
| kbs = cls.model.select().where( | |||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) | |||
| & (cls.model.status==StatusEnum.VALID.value) | |||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == | |||
| TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) | |||
| & (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) | |||
| else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) | |||
| if desc: | |||
| kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) | |||
| else: | |||
| kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) | |||
| kbs = kbs.paginate(page_number, items_per_page) | |||
| return list(kbs.dicts()) | |||
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
    """Fetch one valid knowledge base joined with its valid tenant.

    Args:
        kb_id: Id of the knowledge base to look up.

    Returns:
        A dict built from the first matching row with ``embd_id`` taken from
        the joined tenant, or ``None`` when no row matches.
    """
    kb_model = cls.model
    columns = [
        kb_model.id,
        Tenant.embd_id,
        kb_model.avatar,
        kb_model.name,
        kb_model.description,
        kb_model.permission,
        kb_model.doc_num,
        kb_model.token_num,
        kb_model.chunk_num,
        kb_model.parser_id]
    tenant_join = (Tenant.id == kb_model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)
    rows = kb_model.select(*columns).join(Tenant, on=tenant_join).where(
        (kb_model.id == kb_id),
        (kb_model.status == StatusEnum.VALID.value)
    )
    if not rows:
        return

    first = rows[0]
    detail = first.to_dict()
    # The embedding model id comes from the joined tenant row.
    detail["embd_id"] = first.tenant.embd_id
    return detail
| @@ -33,3 +33,21 @@ class LLMService(CommonService): | |||
| class TenantLLMService(CommonService): | |||
| model = TenantLLM | |||
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
    """Return the tenant's LLM setting for a model type.

    First looks for a tenant row of the given ``model_type`` that already
    names a model. If there is none, falls back to joining the tenant's rows
    with the valid entries of ``LLM`` for the same factory.

    Args:
        tenant_id: Id of the tenant whose settings are looked up.
        model_type: Model type to filter on.

    Returns:
        The first matching row, or ``None`` when nothing matches.
    """
    objs = cls.query(tenant_id=tenant_id, model_type=model_type)
    if objs and objs[0].llm_name:
        return objs[0]

    fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
    objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
        (cls.model.tenant_id == tenant_id),
        (cls.model.model_type == model_type),
        # Compare against the stored value, as the other queries do,
        # not against the enum member itself.
        (LLM.status == StatusEnum.VALID.value)
    )
    if not objs:
        return
    return objs[0]
| @@ -79,7 +79,7 @@ class TenantService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_user_id(cls, user_id): | |||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] | |||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] | |||
| return list(cls.model.select(*fields)\ | |||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ | |||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | |||
| @@ -143,7 +143,7 @@ def filename_type(filename): | |||
| if re.match(r".*\.pdf$", filename): | |||
| return FileType.PDF.value | |||
| if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): | |||
| if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| return FileType.DOC.value | |||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | |||