* clean rust version project * clean rust version project * build python version rag-flow * add alot of apitags/v0.1.0
| class HuEmbedding(Base): | class HuEmbedding(Base): | ||||
| def __init__(self): | |||||
| def __init__(self, key="", model_name=""): | |||||
| """ | """ | ||||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | If you have trouble downloading HuggingFace models, -_^ this might help!! | ||||
| flds = self.Fields() | flds = self.Fields() | ||||
| if self.is_binary_file(fnm): | if self.is_binary_file(fnm): | ||||
| return flds | return flds | ||||
| with open(fnm, "r") as f: | |||||
| txt = f.read() | |||||
| flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] | |||||
| txt = "" | |||||
| if isinstance(fnm, str): | |||||
| with open(fnm, "r") as f: | |||||
| txt = f.read() | |||||
| else: txt = fnm.decode("utf-8") | |||||
| flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] | |||||
| flds.table_chunks = [] | flds.table_chunks = [] | ||||
| return flds | return flds | ||||
| import numpy as np | import numpy as np | ||||
| def index_name(uid): return f"docgpt_{uid}" | |||||
| def index_name(uid): return f"ragflow_{uid}" | |||||
| class Dealer: | class Dealer: |
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import json | import json | ||||
| import logging | |||||
| import os | import os | ||||
| import hashlib | import hashlib | ||||
| import copy | import copy | ||||
| from rag.llm import EmbeddingModel, CvModel | from rag.llm import EmbeddingModel, CvModel | ||||
| from rag.settings import cron_logger, DOC_MAXIMUM_SIZE | from rag.settings import cron_logger, DOC_MAXIMUM_SIZE | ||||
| from rag.utils import ELASTICSEARCH, num_tokens_from_string | |||||
| from rag.utils import ELASTICSEARCH | |||||
| from rag.utils import MINIO | from rag.utils import MINIO | ||||
| from rag.utils import rmSpace, findMaxDt | |||||
| from rag.utils import rmSpace, findMaxTm | |||||
| from rag.nlp import huchunk, huqie, search | from rag.nlp import huchunk, huqie, search | ||||
| from io import BytesIO | from io import BytesIO | ||||
| import pandas as pd | import pandas as pd | ||||
| from web_server.db import LLMType | from web_server.db import LLMType | ||||
| from web_server.db.services.document_service import DocumentService | from web_server.db.services.document_service import DocumentService | ||||
| from web_server.db.services.llm_service import TenantLLMService | from web_server.db.services.llm_service import TenantLLMService | ||||
| from web_server.settings import database_logger | |||||
| from web_server.utils import get_format_time | from web_server.utils import get_format_time | ||||
| from web_server.utils.file_utils import get_project_base_directory | from web_server.utils.file_utils import get_project_base_directory | ||||
| if len(docs) == 0: | if len(docs) == 0: | ||||
| return pd.DataFrame() | return pd.DataFrame() | ||||
| docs = pd.DataFrame(docs) | docs = pd.DataFrame(docs) | ||||
| mtm = str(docs["update_time"].max())[:19] | |||||
| mtm = docs["update_time"].max() | |||||
| cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) | cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) | ||||
| return docs | return docs | ||||
| cron_logger.error("set_progress:({}), {}".format(docid, str(e))) | cron_logger.error("set_progress:({}), {}".format(docid, str(e))) | ||||
| def build(row): | |||||
| def build(row, cvmdl): | |||||
| if row["size"] > DOC_MAXIMUM_SIZE: | if row["size"] > DOC_MAXIMUM_SIZE: | ||||
| set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % | set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % | ||||
| (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) | (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) | ||||
| return [] | return [] | ||||
| res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) | res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) | ||||
| if ELASTICSEARCH.getTotal(res) > 0: | if ELASTICSEARCH.getTotal(res) > 0: | ||||
| ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), | ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), | ||||
| set_progress(row["id"], random.randint(0, 20) / | set_progress(row["id"], random.randint(0, 20) / | ||||
| 100., "Finished preparing! Start to slice file!", True) | 100., "Finished preparing! Start to slice file!", True) | ||||
| try: | try: | ||||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) | |||||
| cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | |||||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) | |||||
| except Exception as e: | except Exception as e: | ||||
| if re.search("(No such file|not found)", str(e)): | if re.search("(No such file|not found)", str(e)): | ||||
| set_progress( | set_progress( | ||||
| row["id"], -1, f"Internal server error: %s" % | row["id"], -1, f"Internal server error: %s" % | ||||
| str(e).replace( | str(e).replace( | ||||
| "'", "")) | "'", "")) | ||||
| cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e))) | |||||
| return [] | return [] | ||||
| if not obj.text_chunks and not obj.table_chunks: | if not obj.text_chunks and not obj.table_chunks: | ||||
| "Finished slicing files. Start to embedding the content.") | "Finished slicing files. Start to embedding the content.") | ||||
| doc = { | doc = { | ||||
| "doc_id": row["did"], | |||||
| "doc_id": row["id"], | |||||
| "kb_id": [str(row["kb_id"])], | "kb_id": [str(row["kb_id"])], | ||||
| "docnm_kwd": os.path.split(row["location"])[-1], | "docnm_kwd": os.path.split(row["location"])[-1], | ||||
| "title_tks": huqie.qie(row["name"]), | "title_tks": huqie.qie(row["name"]), | ||||
| docs.append(d) | docs.append(d) | ||||
| continue | continue | ||||
| if isinstance(img, Image): | |||||
| img.save(output_buffer, format='JPEG') | |||||
| else: | |||||
| if isinstance(img, bytes): | |||||
| output_buffer = BytesIO(img) | output_buffer = BytesIO(img) | ||||
| else: | |||||
| img.save(output_buffer, format='JPEG') | |||||
| MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | ||||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | ||||
def model_instance(tenant_id, llm_type):
    """Instantiate the model client configured for a tenant.

    The tenant's stored LLM configuration is fetched through
    ``TenantLLMService.get_api_key``; when nothing is stored, a ``"local"``
    factory with empty key and model name is used instead.

    Args:
        tenant_id: Tenant whose configuration is looked up.
        llm_type: ``LLMType.EMBEDDING`` or ``LLMType.IMAGE2TEXT``.

    Returns:
        An instance built from ``EmbeddingModel`` or ``CvModel``, or ``None``
        when the configured factory is not registered for the requested type
        or ``llm_type`` is neither of the two supported values.
    """
    # NOTE(review): the lookup always uses LLMType.EMBEDDING, even when an
    # IMAGE2TEXT model is requested -- confirm this is intended.
    model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
    if not model_config:
        model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
    else:
        model_config = model_config[0].to_dict()

    if llm_type == LLMType.EMBEDDING:
        registry = EmbeddingModel
    elif llm_type == LLMType.IMAGE2TEXT:
        registry = CvModel
    else:
        return None

    # model_config is a plain dict at this point, so every field is read by
    # key; attribute access would raise AttributeError.
    factory = model_config["llm_factory"]
    if factory not in registry:
        return None
    return registry[factory](model_config["api_key"], model_config["llm_name"])
| def main(comm, mod): | def main(comm, mod): | ||||
| from rag.llm import HuEmbedding | from rag.llm import HuEmbedding | ||||
| model = HuEmbedding() | model = HuEmbedding() | ||||
| tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") | tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") | ||||
| tm = findMaxDt(tm_fnm) | |||||
| tm = findMaxTm(tm_fnm) | |||||
| rows = collect(comm, mod, tm) | rows = collect(comm, mod, tm) | ||||
| if len(rows) == 0: | if len(rows) == 0: | ||||
| return | return | ||||
| st_tm = timer() | st_tm = timer() | ||||
| cks = build(r, cv_mdl) | cks = build(r, cv_mdl) | ||||
| if not cks: | if not cks: | ||||
| tmf.write(str(r["updated_at"]) + "\n") | |||||
| tmf.write(str(r["update_time"]) + "\n") | |||||
| continue | continue | ||||
| # TODO: exception handler | # TODO: exception handler | ||||
| ## set_progress(r["did"], -1, "ERROR: ") | ## set_progress(r["did"], -1, "ERROR: ") | ||||
| cron_logger.error(str(es_r)) | cron_logger.error(str(es_r)) | ||||
| else: | else: | ||||
| set_progress(r["id"], 1., "Done!") | set_progress(r["id"], 1., "Done!") | ||||
| DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) | |||||
| DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm) | |||||
| cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) | |||||
| tmf.write(str(r["update_time"]) + "\n") | tmf.write(str(r["update_time"]) + "\n") | ||||
| tmf.close() | tmf.close() | ||||
if __name__ == "__main__":
    # Hand peewee's log records to the database logger's first handler and
    # stop them from propagating to the root logger.
    peewee_logger = logging.getLogger('peewee')
    peewee_logger.setLevel(database_logger.level)
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.propagate = False

    # The MPI world size and this process's rank are the two arguments
    # passed to main().
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    main(comm.Get_size(), comm.Get_rank())
| print("WARNING: can't find " + fnm) | print("WARNING: can't find " + fnm) | ||||
| return m | return m | ||||
def findMaxTm(fnm):
    """Return the largest integer found in a file holding one value per line.

    Blank lines, the literal ``nan`` and lines that cannot be parsed as an
    integer are skipped, so one bad entry does not hide later valid ones.

    Args:
        fnm: Path of the file to read.

    Returns:
        The maximum value found, or 0 when the file cannot be read or holds
        no positive integer.
    """
    m = 0
    try:
        with open(fnm, "r") as f:
            for line in f:
                line = line.strip()
                if not line or line == 'nan':
                    continue
                try:
                    value = int(line)
                except ValueError:
                    # Malformed entry: ignore it and keep scanning.
                    continue
                if value > m:
                    m = value
    except (OSError, UnicodeError):
        # Only a file that cannot be opened/read is reported as missing.
        print("WARNING: can't find " + fnm)
    return m
| def num_tokens_from_string(string: str) -> int: | def num_tokens_from_string(string: str) -> int: | ||||
| """Returns the number of tokens in a text string.""" | """Returns the number of tokens in a text string.""" | ||||
| encoding = tiktoken.get_encoding('cl100k_base') | encoding = tiktoken.get_encoding('cl100k_base') |
| except Exception as e: | except Exception as e: | ||||
| es_logger.error("ES updateByQuery deleteByQuery: " + | es_logger.error("ES updateByQuery deleteByQuery: " + | ||||
| str(e) + "【Q】:" + str(query.to_dict())) | str(e) + "【Q】:" + str(query.to_dict())) | ||||
| if str(e).find("NotFoundError") > 0: return True | |||||
| if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: | if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: | ||||
| continue | continue | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import base64 | |||||
| import pathlib | import pathlib | ||||
| from elasticsearch_dsl import Q | from elasticsearch_dsl import Q | ||||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | e, doc = DocumentService.get_by_id(req["doc_id"]) | ||||
| if not e: | if not e: | ||||
| return get_data_error_result(retmsg="Document not found!") | return get_data_error_result(retmsg="Document not found!") | ||||
| if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): | |||||
| return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) | |||||
| DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) | |||||
| if not DocumentService.delete_by_id(req["doc_id"]): | if not DocumentService.delete_by_id(req["doc_id"]): | ||||
| return get_data_error_result( | return get_data_error_result( | ||||
| retmsg="Database error (Document removal)!") | retmsg="Database error (Document removal)!") | ||||
| e, kb = KnowledgebaseService.get_by_id(doc.kb_id) | |||||
| MINIO.rm(kb.id, doc.location) | |||||
| MINIO.rm(doc.kb_id, doc.location) | |||||
| return get_json_result(data=True) | return get_json_result(data=True) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
| return get_json_result(data=True) | return get_json_result(data=True) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return a document's stored file content, base64-encoded.

    Reads ``doc_id`` from the query string, loads the document record and
    fetches its blob from MINIO using the document's ``kb_id`` and
    ``location``.

    Returns:
        A JSON result whose data is ``{"base64": <encoded content>}``, a data
        error result when the document does not exist, or a server error
        response when anything raises.
    """
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        # assumes MINIO.get returns bytes -- TODO confirm
        blob = MINIO.get(doc.kb_id, doc.location)
        # The raw content has to be *encoded* for transport; the result is
        # turned into str because bytes cannot be serialised to JSON.
        return get_json_result(data={"base64": base64.b64encode(blob).decode("utf-8")})
    except Exception as e:
        return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Switch a document to another parser and reset its processing state.

    When the requested parser equals the current one (case-insensitive)
    nothing is changed. Otherwise the parser id is stored, progress and its
    message are cleared, and the document's token/chunk counts and processing
    duration are subtracted through ``DocumentService.increment_chunk_num``.

    Returns:
        A JSON result with ``data=True`` on success, a data error result when
        the document is missing or an update reports failure, or a server
        error response when anything raises.
    """
    req = request.json
    try:
        found, doc = DocumentService.get_by_id(req["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")

        if doc.parser_id.lower() == req["parser_id"].lower():
            return get_json_result(data=True)

        reset_fields = {"parser_id": req["parser_id"], "progress": 0, "progress_msg": ""}
        if not DocumentService.update_by_id(doc.id, reset_fields):
            return get_data_error_result(retmsg="Document not found!")

        # Negative deltas take this document's contribution back out.
        rolled_back = DocumentService.increment_chunk_num(
            doc.id,
            doc.kb_id,
            doc.token_num * -1,
            doc.chunk_num * -1,
            doc.process_duation * -1)
        if not rolled_back:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
| @manager.route('/create', methods=['post']) | @manager.route('/create', methods=['post']) | ||||
| @login_required | @login_required | ||||
| @validate_request("name", "description", "permission", "embd_id", "parser_id") | |||||
| @validate_request("name", "description", "permission", "parser_id") | |||||
| def create(): | def create(): | ||||
| req = request.json | req = request.json | ||||
| req["name"] = req["name"].strip() | req["name"] = req["name"].strip() | ||||
| @manager.route('/update', methods=['post']) | @manager.route('/update', methods=['post']) | ||||
| @login_required | @login_required | ||||
| @validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") | |||||
| @validate_request("kb_id", "name", "description", "permission", "parser_id") | |||||
| def update(): | def update(): | ||||
| req = request.json | req = request.json | ||||
| req["name"] = req["name"].strip() | req["name"] = req["name"].strip() | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of the knowledge base named by ``kb_id``.

    Returns:
        A JSON result carrying whatever ``KnowledgebaseService.get_detail``
        returned, a data error result when that value is falsy, or a server
        error response when the lookup raises.
    """
    kb_id = request.args["kb_id"]
    try:
        kb = KnowledgebaseService.get_detail(kb_id)
        if kb:
            return get_json_result(data=kb)
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    except Exception as e:
        return server_error_response(e)
| @manager.route('/list', methods=['GET']) | @manager.route('/list', methods=['GET']) | ||||
| @login_required | @login_required | ||||
| def list(): | def list(): |
| # | |||||
| # Copyright 2019 The FATE Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| from flask import request | |||||
| from flask_login import login_required, current_user | |||||
| from web_server.db.services import duplicate_name | |||||
| from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||||
| from web_server.db.services.user_service import TenantService, UserTenantService | |||||
| from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||||
| from web_server.utils import get_uuid, get_format_time | |||||
| from web_server.db import StatusEnum, UserTenantRole | |||||
| from web_server.db.services.kb_service import KnowledgebaseService | |||||
| from web_server.db.db_models import Knowledgebase, TenantLLM | |||||
| from web_server.settings import stat_logger, RetCode | |||||
| from web_server.utils.api_utils import get_json_result | |||||
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """Return every LLM factory known to ``LLMFactoriesService``.

    Returns:
        A JSON result wrapping ``get_all().to_json()``, or a server error
        response when the lookup or conversion raises.
    """
    try:
        return get_json_result(data=LLMFactoriesService.get_all().to_json())
    except Exception as e:
        return server_error_response(e)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (insert or replace) an API key for the current user's tenant.

    ``llm_factory`` and ``api_key`` are required in the JSON body;
    ``model_type`` and ``llm_name`` are copied into the row when present.

    Returns:
        A JSON result with ``data=True`` on success, or a server error
        response when reading the request or writing the row raises.
    """
    req = request.json
    # Failures are reported through server_error_response, like the other
    # routes of this module, instead of escaping as an unhandled exception.
    try:
        llm = {
            "tenant_id": current_user.id,
            "llm_factory": req["llm_factory"],
            "api_key": req["api_key"]
        }
        # TODO: Test api_key
        for n in ["model_type", "llm_name"]:
            if n in req:
                llm[n] = req[n]

        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """List the current user's tenant LLM rows with the API key removed.

    Returns:
        A JSON result holding one dict per row (without ``api_key``), or a
        server error response when anything raises.
    """
    try:
        records = []
        for row in TenantLLMService.query(tenant_id=current_user.id):
            record = row.to_dict()
            # The key itself must never be sent back to the client.
            del record["api_key"]
            records.append(record)
        return get_json_result(data=records)
    except Exception as e:
        return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return all valid LLMs grouped by factory, flagged with availability.

    A model is marked ``available`` when the current user has a row with an
    API key for the model's factory and that factory either has no model
    names recorded or records this model's name.

    Returns:
        A JSON result mapping factory id to a list of model dicts, or a
        server error response when anything raises.
    """
    try:
        # factory name -> model names configured by this tenant (keyed rows only)
        configured = {}
        for row in TenantLLMService.query(tenant_id=current_user.id):
            if not row.api_key:
                continue
            cfg = row.to_dict()
            names = configured.setdefault(cfg["llm_factory"], [])
            if cfg["llm_name"]:
                names.append(cfg["llm_name"])

        grouped = {}
        for model in LLMService.get_all():
            if model.status == StatusEnum.VALID.value:
                m = model.to_dict()
                names = configured.get(m["fid"])
                # An empty name list means the key covers the whole factory.
                m["available"] = names is not None and (not names or m["llm_name"] in names)
                grouped.setdefault(m["fid"], []).append(m)

        return get_json_result(data=grouped)
    except Exception as e:
        return server_error_response(e)
| from flask import request, session, redirect, url_for | from flask import request, session, redirect, url_for | ||||
| from werkzeug.security import generate_password_hash, check_password_hash | from werkzeug.security import generate_password_hash, check_password_hash | ||||
| from flask_login import login_required, current_user, login_user, logout_user | from flask_login import login_required, current_user, login_user, logout_user | ||||
| from web_server.db.db_models import TenantLLM | |||||
| from web_server.db.services.llm_service import TenantLLMService | |||||
| from web_server.utils.api_utils import server_error_response, validate_request | from web_server.utils.api_utils import server_error_response, validate_request | ||||
| from web_server.utils import get_uuid, get_format_time, decrypt, download_img | from web_server.utils import get_uuid, get_format_time, decrypt, download_img | ||||
| from web_server.db import UserTenantRole | |||||
| from web_server.db import UserTenantRole, LLMType | |||||
| from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS | from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS | ||||
| from web_server.db.services.user_service import UserService, TenantService, UserTenantService | from web_server.db.services.user_service import UserService, TenantService, UserTenantService | ||||
| from web_server.settings import stat_logger | from web_server.settings import stat_logger | ||||
| avatar = download_img(userinfo["avatar_url"]) | avatar = download_img(userinfo["avatar_url"]) | ||||
| except Exception as e: | except Exception as e: | ||||
| stat_logger.exception(e) | stat_logger.exception(e) | ||||
| user_id = get_uuid() | |||||
| try: | try: | ||||
| users = user_register({ | |||||
| users = user_register(user_id, { | |||||
| "access_token": session["access_token"], | "access_token": session["access_token"], | ||||
| "email": userinfo["email"], | "email": userinfo["email"], | ||||
| "avatar": avatar, | "avatar": avatar, | ||||
| login_user(user) | login_user(user) | ||||
| return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") | return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") | ||||
| except Exception as e: | except Exception as e: | ||||
| rollback_user_registration(user_id) | |||||
| stat_logger.exception(e) | stat_logger.exception(e) | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
| elif not request.json: | elif not request.json: | ||||
| return get_json_result(data=current_user.to_dict()) | return get_json_result(data=current_user.to_dict()) | ||||
| def user_register(user): | |||||
def rollback_user_registration(user_id):
    """Best-effort cleanup of rows created for a failed registration.

    Removes, independently of each other, the tenant with id ``user_id``, the
    first user-tenant row found for that tenant, and all ``TenantLLM`` rows of
    that tenant. A failure in one step is logged and does not prevent the
    remaining steps from running; nothing is raised to the caller.

    Args:
        user_id: Id shared by the user and the tenant created for them.
    """
    # NOTE(review): the user row itself is not removed here -- confirm that
    # is intended.
    try:
        TenantService.delete_by_id(user_id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        stat_logger.exception(e)
| def user_register(user_id, user): | |||||
| user_id = get_uuid() | user_id = get_uuid() | ||||
| user["id"] = user_id | user["id"] = user_id | ||||
| tenant = { | tenant = { | ||||
| "invited_by": user_id, | "invited_by": user_id, | ||||
| "role": UserTenantRole.OWNER | "role": UserTenantRole.OWNER | ||||
| } | } | ||||
| tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"} | |||||
| if not UserService.save(**user):return | if not UserService.save(**user):return | ||||
| TenantService.save(**tenant) | TenantService.save(**tenant) | ||||
| UserTenantService.save(**usr_tenant) | UserTenantService.save(**usr_tenant) | ||||
| TenantLLMService.save(**tenant_llm) | |||||
| return UserService.query(email=user["email"]) | return UserService.query(email=user["email"]) | ||||
| "last_login_time": get_format_time(), | "last_login_time": get_format_time(), | ||||
| "is_superuser": False, | "is_superuser": False, | ||||
| } | } | ||||
| user_id = get_uuid() | |||||
| try: | try: | ||||
| users = user_register(user_dict) | |||||
| users = user_register(user_id, user_dict) | |||||
| if not users: raise Exception('Register user failure.') | if not users: raise Exception('Register user failure.') | ||||
| if len(users) > 1: raise Exception('Same E-mail exist!') | if len(users) > 1: raise Exception('Same E-mail exist!') | ||||
| user = users[0] | user = users[0] | ||||
| login_user(user) | login_user(user) | ||||
| return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") | return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") | ||||
| except Exception as e: | except Exception as e: | ||||
| rollback_user_registration(user_id) | |||||
| stat_logger.exception(e) | stat_logger.exception(e) | ||||
| return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) | return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) | ||||
| @login_required | @login_required | ||||
| def tenant_info(): | def tenant_info(): | ||||
| try: | try: | ||||
| tenants = TenantService.get_by_user_id(current_user.id) | |||||
| tenants = TenantService.get_by_user_id(current_user.id)[0] | |||||
| return get_json_result(data=tenants) | return get_json_result(data=tenants) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) |
| class LLM(DataBaseModel): | class LLM(DataBaseModel): | ||||
| # defautlt LLMs for every users | # defautlt LLMs for every users | ||||
| llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) | llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) | ||||
| model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") | |||||
| fid = CharField(max_length=128, null=False, help_text="LLM factory id") | fid = CharField(max_length=128, null=False, help_text="LLM factory id") | ||||
| tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") | tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") | ||||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | ||||
| class TenantLLM(DataBaseModel): | class TenantLLM(DataBaseModel): | ||||
| tenant_id = CharField(max_length=32, null=False) | tenant_id = CharField(max_length=32, null=False) | ||||
| llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") | llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") | ||||
| model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") | |||||
| llm_name = CharField(max_length=128, null=False, help_text="LLM name") | |||||
| model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") | |||||
| llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") | |||||
| api_key = CharField(max_length=255, null=True, help_text="API KEY") | api_key = CharField(max_length=255, null=True, help_text="API KEY") | ||||
| api_base = CharField(max_length=255, null=True, help_text="API Base") | api_base = CharField(max_length=255, null=True, help_text="API Base") | ||||
| class Meta: | class Meta: | ||||
| db_table = "tenant_llm" | db_table = "tenant_llm" | ||||
| primary_key = CompositeKey('tenant_id', 'llm_factory') | |||||
| primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name') | |||||
| class Knowledgebase(DataBaseModel): | class Knowledgebase(DataBaseModel): | ||||
| permission = CharField(max_length=16, null=False, help_text="me|team") | permission = CharField(max_length=16, null=False, help_text="me|team") | ||||
| created_by = CharField(max_length=32, null=False) | created_by = CharField(max_length=32, null=False) | ||||
| doc_num = IntegerField(default=0) | doc_num = IntegerField(default=0) | ||||
| embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") | |||||
| token_num = IntegerField(default=0) | |||||
| chunk_num = IntegerField(default=0) | |||||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID") | parser_id = CharField(max_length=32, null=False, help_text="default parser ID") | ||||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| from peewee import Expression | |||||
| from web_server.db import TenantPermission, FileType | from web_server.db import TenantPermission, FileType | ||||
| from web_server.db.db_models import DB, Knowledgebase | |||||
| from web_server.db.db_models import DB, Knowledgebase, Tenant | |||||
| from web_server.db.db_models import Document | from web_server.db.db_models import Document | ||||
| from web_server.db.services.common_service import CommonService | from web_server.db.services.common_service import CommonService | ||||
| from web_server.db.services.kb_service import KnowledgebaseService | from web_server.db.services.kb_service import KnowledgebaseService | ||||
| from web_server.utils import get_uuid, get_format_time | |||||
| from web_server.db.db_utils import StatusEnum | from web_server.db.db_utils import StatusEnum | ||||
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): | def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): | ||||
| fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] | |||||
| docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( | |||||
| cls.model.status == StatusEnum.VALID.value, | |||||
| cls.model.type != FileType.VIRTUAL, | |||||
| cls.model.progress == 0, | |||||
| cls.model.update_time >= tm, | |||||
| cls.model.create_time % | |||||
| comm == mod).order_by( | |||||
| cls.model.update_time.asc()).paginate( | |||||
| 1, | |||||
| items_per_page) | |||||
| fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time] | |||||
| docs = cls.model.select(*fields) \ | |||||
| .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ | |||||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ | |||||
| .where( | |||||
| cls.model.status == StatusEnum.VALID.value, | |||||
| ~(cls.model.type == FileType.VIRTUAL.value), | |||||
| cls.model.progress == 0, | |||||
| cls.model.update_time >= tm, | |||||
| (Expression(cls.model.create_time, "%%", comm) == mod))\ | |||||
| .order_by(cls.model.update_time.asc())\ | |||||
| .paginate(1, items_per_page) | |||||
| return list(docs.dicts()) | return list(docs.dicts()) | ||||
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add deltas to a document's counters and to its knowledge base totals.

    Args:
        doc_id: Id of the document row to update.
        kb_id: Id of the knowledge base row to update.
        token_num: Amount added to ``token_num`` on both rows.
        chunk_num: Amount added to ``chunk_num`` on both rows.
        duation: Amount added to the document's ``process_duation``.

    Returns:
        The value returned by executing the knowledge base update.

    Raises:
        LookupError: When the document update affects no row.
    """
    doc_model = cls.model
    doc_update = doc_model.update(
        token_num=doc_model.token_num + token_num,
        chunk_num=doc_model.chunk_num + chunk_num,
        process_duation=doc_model.process_duation + duation,
    ).where(doc_model.id == doc_id)
    if doc_update.execute() == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_update = Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num,
    ).where(Knowledgebase.id == kb_id)
    return kb_update.execute()
| from werkzeug.security import generate_password_hash, check_password_hash | from werkzeug.security import generate_password_hash, check_password_hash | ||||
| from web_server.db import TenantPermission | from web_server.db import TenantPermission | ||||
| from web_server.db.db_models import DB, UserTenant | |||||
| from web_server.db.db_models import DB, UserTenant, Tenant | |||||
| from web_server.db.db_models import Knowledgebase | from web_server.db.db_models import Knowledgebase | ||||
| from web_server.db.services.common_service import CommonService | from web_server.db.services.common_service import CommonService | ||||
| from web_server.utils import get_uuid, get_format_time | from web_server.utils import get_uuid, get_format_time | ||||
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): | |||||
| def get_by_tenant_ids(cls, joined_tenant_ids, user_id, | |||||
| page_number, items_per_page, orderby, desc): | |||||
| kbs = cls.model.select().where( | kbs = cls.model.select().where( | ||||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) | |||||
| & (cls.model.status==StatusEnum.VALID.value) | |||||
| ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == | |||||
| TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) | |||||
| & (cls.model.status == StatusEnum.VALID.value) | |||||
| ) | ) | ||||
| if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) | |||||
| else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) | |||||
| if desc: | |||||
| kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) | |||||
| else: | |||||
| kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) | |||||
| kbs = kbs.paginate(page_number, items_per_page) | kbs = kbs.paginate(page_number, items_per_page) | ||||
| return list(kbs.dicts()) | return list(kbs.dicts()) | ||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_detail(cls, kb_id): | |||||
| fields = [ | |||||
| cls.model.id, | |||||
| Tenant.embd_id, | |||||
| cls.model.avatar, | |||||
| cls.model.name, | |||||
| cls.model.description, | |||||
| cls.model.permission, | |||||
| cls.model.doc_num, | |||||
| cls.model.token_num, | |||||
| cls.model.chunk_num, | |||||
| cls.model.parser_id] | |||||
| kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( | |||||
| (cls.model.id == kb_id), | |||||
| (cls.model.status == StatusEnum.VALID.value) | |||||
| ) | |||||
| if not kbs: | |||||
| return | |||||
| d = kbs[0].to_dict() | |||||
| d["embd_id"] = kbs[0].tenant.embd_id | |||||
| return d |
| class TenantLLMService(CommonService): | class TenantLLMService(CommonService): | ||||
| model = TenantLLM | model = TenantLLM | ||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_api_key(cls, tenant_id, model_type): | |||||
| objs = cls.query(tenant_id=tenant_id, model_type=model_type) | |||||
| if objs and len(objs)>0 and objs[0].llm_name: | |||||
| return objs[0] | |||||
| fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key] | |||||
| objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where( | |||||
| (cls.model.tenant_id == tenant_id), | |||||
| (cls.model.model_type == model_type), | |||||
| (LLM.status == StatusEnum.VALID) | |||||
| ) | |||||
| if not objs:return | |||||
| return objs[0] | |||||
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def get_by_user_id(cls, user_id): | def get_by_user_id(cls, user_id): | ||||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] | |||||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] | |||||
| return list(cls.model.select(*fields)\ | return list(cls.model.select(*fields)\ | ||||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ | .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ | ||||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | .where(cls.model.status == StatusEnum.VALID.value).dicts()) |
| if re.match(r".*\.pdf$", filename): | if re.match(r".*\.pdf$", filename): | ||||
| return FileType.PDF.value | return FileType.PDF.value | ||||
| if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): | |||||
| if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||||
| return FileType.DOC.value | return FileType.DOC.value | ||||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): |