### What problem does this PR solve?

#384

### Type of change

- [x] Performance Improvement
```diff
@@ -72,8 +72,8 @@

 ### 📝 Prerequisites

-- CPU >= 2 cores
-- RAM >= 8 GB
+- CPU >= 4 cores
+- RAM >= 12 GB
 - Docker >= 24.0.0 & Docker Compose >= v2.26.1
 > If you have not installed Docker on your local machine (Windows, Mac, or Linux), see [Install Docker Engine](https://docs.docker.com/engine/install/).
```
The same change in the Japanese README:

```diff
@@ -72,8 +72,8 @@

 ### 📝 Prerequisites

-- CPU >= 2 cores
-- RAM >= 8 GB
+- CPU >= 4 cores
+- RAM >= 12 GB
 - Docker >= 24.0.0 & Docker Compose >= v2.26.1
 > If you have not installed Docker on your local machine (Windows, Mac, or Linux), see [Install Docker Engine](https://docs.docker.com/engine/install/).
```
And in the Chinese README:

```diff
@@ -72,8 +72,8 @@

 ### 📝 Prerequisites

-- CPU >= 2 cores
-- RAM >= 8 GB
+- CPU >= 4 cores
+- RAM >= 12 GB
 - Docker >= 24.0.0 & Docker Compose >= v2.26.1
 > If you have not installed Docker on your local machine (Windows, Mac, or Linux), refer to [Install Docker Engine](https://docs.docker.com/engine/install/) to install it yourself.
```
```diff
@@ -105,7 +105,7 @@ def stats():
     res = {
         "pv": [(o["dt"], o["pv"]) for o in objs],
         "uv": [(o["dt"], o["uv"]) for o in objs],
-        "speed": [(o["dt"], float(o["tokens"])/float(o["duration"])) for o in objs],
+        "speed": [(o["dt"], float(o["tokens"])/(float(o["duration"]+0.1))) for o in objs],
         "tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs],
         "round": [(o["dt"], o["round"]) for o in objs],
         "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
```
```diff
@@ -176,7 +176,6 @@ def completion():
         conv.reference.append(ans["reference"])
         conv.message.append({"role": "assistant", "content": ans["answer"]})
         API4ConversationService.append_message(conv.id, conv.to_dict())
-        APITokenService.APITokenService(token)
         return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)
```
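The deleted line reads like a module-qualified constructor, but with `APITokenService` imported as a class, `APITokenService.APITokenService(token)` is an attribute lookup that presumably raised `AttributeError` and pushed every completion through `server_error_response`; the statement had no other effect, so it is simply dropped.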
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+from datetime import datetime
 from flask import request, session, redirect
 from werkzeug.security import generate_password_hash, check_password_hash
```
```diff
@@ -22,7 +23,7 @@ from flask_login import login_required, current_user, login_user, logout_user
 from api.db.db_models import TenantLLM
 from api.db.services.llm_service import TenantLLMService, LLMService
 from api.utils.api_utils import server_error_response, validate_request
-from api.utils import get_uuid, get_format_time, decrypt, download_img
+from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
 from api.db import UserTenantRole, LLMType
 from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
     LLM_FACTORY, LLM_BASE_URL
```
```diff
@@ -56,6 +57,8 @@ def login():
         response_data = user.to_json()
         user.access_token = get_uuid()
         login_user(user)
+        user.update_time = current_timestamp()
+        user.update_date = datetime_format(datetime.now())
         user.save()
         msg = "Welcome back!"
         return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
```
```diff
@@ -40,8 +40,8 @@ class API4ConversationService(CommonService):
     @classmethod
     @DB.connection_context()
     def append_message(cls, id, conversation):
-        cls.model.update_by_id(id, conversation)
-        return cls.model.update(round=cls.model.round + 1).where(id=id).execute()
+        cls.update_by_id(id, conversation)
+        return cls.model.update(round=cls.model.round + 1).where(cls.model.id == id).execute()

     @classmethod
     @DB.connection_context()
```
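Two peewee misuses are fixed at once: the `update_by_id` helper is defined on the service base class, not on the peewee model, and `.where()` accepts expressions rather than keyword arguments, so `.where(id=id)` raised a `TypeError` before any row was touched. A self-contained sketch of the corrected expression style against a throwaway in-memory model (names are illustrative, not RAGFlow's schema):

```python
from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(":memory:")

class Conversation(Model):
    id = CharField(primary_key=True)
    round = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([Conversation])
Conversation.create(id="c1", round=0)

# Wrong: .where(id="c1") raises TypeError -- peewee's where() takes
# expressions only. Right: compare the field to the value.
(Conversation
 .update(round=Conversation.round + 1)
 .where(Conversation.id == "c1")
 .execute())

print(Conversation.get_by_id("c1").round)  # 1
```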
```diff
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import random
+from peewee import Expression
 from api.db.db_models import DB
 from api.db import StatusEnum, FileType, TaskStatus
```
```diff
@@ -26,7 +28,7 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
-    def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
+    def get_tasks(cls, tm, mod=0, comm=1, items_per_page=1, takeit=True):
         fields = [
             cls.model.id,
             cls.model.doc_id,
```
```diff
@@ -45,20 +47,28 @@ class TaskService(CommonService):
             Tenant.img2txt_id,
             Tenant.asr_id,
             cls.model.update_time]
-        docs = cls.model.select(*fields) \
-            .join(Document, on=(cls.model.doc_id == Document.id)) \
-            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
-            .where(
-                Document.status == StatusEnum.VALID.value,
-                Document.run == TaskStatus.RUNNING.value,
-                ~(Document.type == FileType.VIRTUAL.value),
-                cls.model.progress == 0,
-                cls.model.update_time >= tm,
-                (Expression(cls.model.create_time, "%%", comm) == mod))\
-            .order_by(cls.model.update_time.asc())\
-            .paginate(1, items_per_page)
-        return list(docs.dicts())
+        with DB.lock("get_task", -1):
+            docs = cls.model.select(*fields) \
+                .join(Document, on=(cls.model.doc_id == Document.id)) \
+                .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
+                .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
+                .where(
+                    Document.status == StatusEnum.VALID.value,
+                    Document.run == TaskStatus.RUNNING.value,
+                    ~(Document.type == FileType.VIRTUAL.value),
+                    cls.model.progress == 0,
+                    #cls.model.update_time >= tm,
+                    #(Expression(cls.model.create_time, "%%", comm) == mod)
+                )\
+                .order_by(cls.model.update_time.asc())\
+                .paginate(0, items_per_page)
+            docs = list(docs.dicts())
+            if not docs: return []
+            if not takeit: return docs
+            cls.model.update(progress_msg=cls.model.progress_msg + "\n" + "Task has been received.", progress=random.random()/10.).where(
+                cls.model.id == docs[0]["id"]).execute()
+            return docs

     @classmethod
     @DB.connection_context()
```
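`get_tasks` now doubles as an atomic claim: the named `DB.lock` serializes concurrent pollers, `items_per_page` drops from 64 to 1 so each worker takes a single task, and the taker immediately bumps `progress` to a small random non-zero value, which removes the row from every other worker's `progress == 0` filter. The modulo sharding on `create_time` (fed by the MPI `mod`/`comm` pair) becomes unnecessary and is commented out. A sketch of the claim pattern, assuming `DB.lock("get_task", -1)` wraps MySQL's `GET_LOCK`/`RELEASE_LOCK` (the helper, model, and field names here are illustrative):

```python
import random
from contextlib import contextmanager

@contextmanager
def named_lock(db, name, timeout=-1):
    # Assumption: DB.lock is a cross-process mutex built on MySQL's
    # GET_LOCK; a negative timeout means "wait forever".
    db.execute_sql("SELECT GET_LOCK(%s, %s)", (name, timeout))
    try:
        yield
    finally:
        db.execute_sql("SELECT RELEASE_LOCK(%s)", (name,))

def claim_one_task(db, Task):
    with named_lock(db, "get_task"):
        rows = list(Task.select()
                        .where(Task.progress == 0)
                        .order_by(Task.update_time.asc())
                        .paginate(1, 1)          # at most one task per poll
                        .dicts())
        if not rows:
            return None
        # Marking progress non-zero inside the lock is what makes the
        # claim atomic: every other worker filters on progress == 0.
        (Task.update(progress=random.random() / 10.,
                     progress_msg=Task.progress_msg + "\nTask has been received.")
             .where(Task.id == rows[0]["id"])
             .execute())
        return rows[0]
```

For what it's worth, peewee treats page numbers below 1 as the first page, so the `paginate(1, …)` to `paginate(0, …)` change above should be behaviorally neutral.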
```diff
@@ -74,9 +84,10 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_progress(cls, id, info):
-        if info["progress_msg"]:
-            cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-                cls.model.id == id).execute()
-        if "progress" in info:
-            cls.model.update(progress=info["progress"]).where(
-                cls.model.id == id).execute()
+        with DB.lock("update_progress", -1):
+            if info["progress_msg"]:
+                cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
+                    cls.model.id == id).execute()
+            if "progress" in info:
+                cls.model.update(progress=info["progress"]).where(
+                    cls.model.id == id).execute()
```
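The same named-lock pattern wraps `update_progress`, so the read-modify-write append to `progress_msg` cannot interleave when several workers (or several pages of one document) report progress at once.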
```diff
@@ -3,6 +3,8 @@ from openpyxl import load_workbook
 import sys
 from io import BytesIO

+from rag.nlp import find_codec
+

 class HuExcelParser:
     def html(self, fnm):
```
```diff
@@ -66,7 +68,8 @@ class HuExcelParser:
             return total

         if fnm.split(".")[-1].lower() in ["csv", "txt"]:
-            txt = binary.decode("utf-8")
+            encoding = find_codec(binary)
+            txt = binary.decode(encoding)
             return len(txt.split("\n"))
```
```diff
@@ -15,7 +15,8 @@ import re
 from io import BytesIO
 from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, \
-    hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, add_positions, tokenize_chunks
+    hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, add_positions, \
+    tokenize_chunks, find_codec
 from rag.nlp import huqie
 from deepdoc.parser import PdfParser, DocxParser, PlainParser
```
```diff
@@ -87,7 +88,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        txt = binary.decode("utf-8")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
```diff
@@ -17,7 +17,7 @@ from docx import Document
 from api.db import ParserType
 from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
-    make_colon_as_title, add_positions, tokenize_chunks
+    make_colon_as_title, add_positions, tokenize_chunks, find_codec
 from rag.nlp import huqie
 from deepdoc.parser import PdfParser, DocxParser, PlainParser
 from rag.settings import cron_logger
```
```diff
@@ -111,7 +111,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        txt = binary.decode("utf-8")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
```diff
@@ -14,7 +14,7 @@ from io import BytesIO
 from docx import Document
 import re
 from deepdoc.parser.pdf_parser import PlainParser
-from rag.nlp import huqie, naive_merge, tokenize_table, tokenize_chunks
+from rag.nlp import huqie, naive_merge, tokenize_table, tokenize_chunks, find_codec
 from deepdoc.parser import PdfParser, ExcelParser, DocxParser
 from rag.settings import cron_logger
```
```diff
@@ -139,10 +139,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        try:
-            txt = binary.decode("utf-8")
-        except Exception as e:
-            txt = binary.decode("gb2312")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
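Every one of these call sites previously assumed UTF-8 (this one with a gb2312 second guess); anything else still blew up. They now detect the codec first via `find_codec`, added to `rag/nlp` further below. A quick illustration of why the two-codec fallback was brittle:

```python
# A Shift-JIS payload defeats both hard-coded guesses, so the old
# except-branch raised anyway; codec detection handles it.
blob = "日本語のテキスト".encode("shift_jis")
for codec in ("utf-8", "gb2312"):
    try:
        blob.decode(codec)
        print(codec, "ok")
    except UnicodeDecodeError:
        print(codec, "failed")   # both codecs print "failed"
```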
```diff
@@ -12,7 +12,7 @@
 #
 import re
 from rag.app import laws
-from rag.nlp import huqie, tokenize
+from rag.nlp import huqie, tokenize, find_codec
 from deepdoc.parser import PdfParser, ExcelParser, PlainParser
```
```diff
@@ -82,7 +82,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        txt = binary.decode("utf-8")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
```diff
@@ -15,7 +15,7 @@ from copy import deepcopy
 from io import BytesIO
 from nltk import word_tokenize
 from openpyxl import load_workbook
-from rag.nlp import is_english, random_choices
+from rag.nlp import is_english, random_choices, find_codec
 from rag.nlp import huqie
 from deepdoc.parser import ExcelParser
```
```diff
@@ -106,7 +106,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        txt = binary.decode("utf-8")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
```diff
@@ -20,7 +20,7 @@ from openpyxl import load_workbook
 from dateutil.parser import parse as datetime_parse
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from rag.nlp import huqie, is_english, tokenize
+from rag.nlp import huqie, is_english, tokenize, find_codec
 from deepdoc.parser import ExcelParser
```
```diff
@@ -147,7 +147,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
     callback(0.1, "Start to parse.")
     txt = ""
     if binary:
-        txt = binary.decode("utf-8")
+        encoding = find_codec(binary)
+        txt = binary.decode(encoding)
     else:
         with open(filename, "r") as f:
             while True:
```
```diff
@@ -6,6 +6,35 @@ from . import huqie
 import re
 import copy

+all_codecs = [
+    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
+    'cp037', 'cp273', 'cp424', 'cp437',
+    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
+    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
+    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
+    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
+    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
+    'gb2312', 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
+    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
+    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
+    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
+    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
+    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
+    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
+    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7'
+]
+
+
+def find_codec(blob):
+    global all_codecs
+    for c in all_codecs:
+        try:
+            blob.decode(c)
+            return c
+        except Exception as e:
+            pass
+    return "utf-8"
+

 BULLET_PATTERN = [[
     r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
```
```diff
@@ -8,6 +8,7 @@ import re
 import string
 import sys
 from hanziconv import HanziConv
+from huggingface_hub import snapshot_download
 from nltk import word_tokenize
 from nltk.stem import PorterStemmer, WordNetLemmatizer
 from api.utils.file_utils import get_project_base_directory
```
```diff
@@ -68,7 +68,7 @@ class Dealer:
         pg = int(req.get("page", 1)) - 1
         ps = int(req.get("size", 1000))
         topk = int(req.get("topk", 1024))
-        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
+        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                  "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
                                  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
```
```diff
@@ -289,8 +289,18 @@ class Dealer:
                     sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [sres.field[i][cfield].split(" ")
-                  for i in sres.ids]
+        for i in sres.ids:
+            if isinstance(sres.field[i].get("important_kwd", []), str):
+                sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
+        ins_tw = []
+        for i in sres.ids:
+            content_ltks = sres.field[i][cfield].split(" ")
+            title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
+            important_kwd = sres.field[i].get("important_kwd", [])
+            tks = content_ltks + title_tks + important_kwd
+            ins_tw.append(tks)
+
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
                                                         keywords,
```
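Reranking previously scored term similarity on `content_ltks` alone; the token bag now also folds in title tokens and manually tagged keywords (both newly added to the default `fields` list above), and `important_kwd` is normalized to a list first, presumably because a document with a single keyword comes back as a bare string. A sketch of the bag construction for one hit, with illustrative field values:

```python
# One search hit, as the fields arrive from the index.
hit = {
    "content_ltks": "llm rag retrieval",
    "title_tks": "deploy guide",
    "important_kwd": "docker",            # single keyword arrives as a string
}
kwd = hit.get("important_kwd", [])
if isinstance(kwd, str):                  # normalize before list concatenation
    kwd = [kwd]
tokens = (hit["content_ltks"].split(" ")
          + [t for t in hit.get("title_tks", "").split(" ") if t]
          + kwd)
print(tokens)  # ['llm', 'rag', 'retrieval', 'deploy', 'guide', 'docker']
```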
```diff
@@ -368,7 +378,7 @@ class Dealer:
     def sql_retrieval(self, sql, fetch_size=128, format="json"):
         from api.settings import chat_logger
-        sql = re.sub(r"[ ]+", " ", sql)
+        sql = re.sub(r"[ `]+", " ", sql)
         sql = sql.replace("%", "")
         es_logger.info(f"Get es sql: {sql}")
         replaces = []
```
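LLM-generated SQL tends to quote identifiers MySQL-style with backticks, which Elasticsearch's SQL dialect rejects; widening the whitespace-collapsing character class to include the backtick strips them in the same pass. A quick check:

```python
import re

sql = "SELECT `doc_id` FROM `ragflow_idx` WHERE kb_id = 1"
print(re.sub(r"[ `]+", " ", sql))
# SELECT doc_id FROM ragflow_idx WHERE kb_id = 1
```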
```diff
@@ -121,6 +121,7 @@ def dispatch():
             tsks.append(new_task())
         bulk_insert_into_db(Task, tsks, True)
         print("TSK:", len(tsks))
+        set_dispatching(r["id"])
     except Exception as e:
         cron_logger.exception(e)
```
```diff
@@ -19,6 +19,7 @@ import logging
 import os
 import hashlib
 import copy
+import random
 import re
 import sys
 import time
```
```diff
@@ -92,6 +93,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
 def collect(comm, mod, tm):
     tasks = TaskService.get_tasks(tm, mod, comm)
+    #print(tasks)
     if len(tasks) == 0:
         time.sleep(1)
         return pd.DataFrame()
```
```diff
@@ -243,6 +245,7 @@ def main(comm, mod):
     tmf = open(tm_fnm, "a+")
     for _, r in rows.iterrows():
         callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
+        #callback(random.random()/10., "Task has been received.")
         try:
             embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
         except Exception as e:
```
```diff
@@ -300,9 +303,8 @@ if __name__ == "__main__":
     peewee_logger.addHandler(database_logger.handlers[0])
     peewee_logger.setLevel(database_logger.level)

-    from mpi4py import MPI
-    comm = MPI.COMM_WORLD
+    #from mpi4py import MPI
+    #comm = MPI.COMM_WORLD
     while True:
         main(int(sys.argv[2]), int(sys.argv[1]))
         close_connection()
```
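With task claiming moved into the database lock, the executor no longer needs an MPI communicator to shard work, so the `mpi4py` import is retired; the `comm`/`mod` CLI arguments remain but only feed the now-commented modulo filter. A hedged sketch of spawning several executors as plain OS processes, coordinated only through the database (argument order follows the `main(int(sys.argv[2]), int(sys.argv[1]))` call above; the path and worker count are illustrative):

```python
import subprocess
import sys

WORKERS = 4
# argv[1] = mod (worker rank), argv[2] = comm (worker count), matching the
# main() call above; both are vestigial once the modulo filter is disabled.
procs = [subprocess.Popen([sys.executable, "rag/svr/task_executor.py",
                           str(rank), str(WORKERS)])
         for rank in range(WORKERS)]
for p in procs:
    p.wait()
```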