| @@ -20,7 +20,7 @@ from flask_login import login_required, current_user | |||
| from elasticsearch_dsl import Q | |||
| from rag.app.qa import rmPrefix, beAdoc | |||
| from rag.nlp import search, huqie, retrievaler | |||
| from rag.nlp import search, huqie | |||
| from rag.utils import ELASTICSEARCH, rmSpace | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| @@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.db.services.document_service import DocumentService | |||
| from api.settings import RetCode | |||
| from api.settings import RetCode, retrievaler | |||
| from api.utils.api_utils import get_json_result | |||
| import hashlib | |||
| import re | |||
| @@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService | |||
| from api.db import LLMType | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMService, LLMBundle | |||
| from api.settings import access_logger, stat_logger | |||
| from api.settings import access_logger, stat_logger, retrievaler | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_json_result | |||
| from rag.app.resume import forbidden_select_fields4resume | |||
| from rag.llm import ChatModel | |||
| from rag.nlp import retrievaler | |||
| from rag.nlp.search import index_name | |||
| from rag.utils import num_tokens_from_string, encoder, rmSpace | |||
| @@ -16,10 +16,12 @@ | |||
| import time | |||
| import uuid | |||
| from api.db import LLMType | |||
| from api.db import LLMType, UserTenantRole | |||
| from api.db.db_models import init_database_tables as init_web_db | |||
| from api.db.services import UserService | |||
| from api.db.services.llm_service import LLMFactoriesService, LLMService | |||
| from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY | |||
| def init_superuser(): | |||
| @@ -32,8 +34,44 @@ def init_superuser(): | |||
| "creator": "system", | |||
| "status": "1", | |||
| } | |||
| tenant = { | |||
| "id": user_info["id"], | |||
| "name": user_info["nickname"] + "‘s Kingdom", | |||
| "llm_id": CHAT_MDL, | |||
| "embd_id": EMBEDDING_MDL, | |||
| "asr_id": ASR_MDL, | |||
| "parser_ids": PARSERS, | |||
| "img2txt_id": IMAGE2TEXT_MDL | |||
| } | |||
| usr_tenant = { | |||
| "tenant_id": user_info["id"], | |||
| "user_id": user_info["id"], | |||
| "invited_by": user_info["id"], | |||
| "role": UserTenantRole.OWNER | |||
| } | |||
| tenant_llm = [] | |||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||
| tenant_llm.append( | |||
| {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, | |||
| "api_key": API_KEY}) | |||
| if not UserService.save(**user_info): | |||
| print("【ERROR】can't init admin.") | |||
| return | |||
| TenantService.save(**tenant) | |||
| UserTenantService.save(**usr_tenant) | |||
| TenantLLMService.insert_many(tenant_llm) | |||
| UserService.save(**user_info) | |||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | |||
| msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| if msg.find("ERROR: ") == 0: | |||
| print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg) | |||
| embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"]) | |||
| v,c = embd_mdl.encode(["Hello!"]) | |||
| if c == 0: | |||
| print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"])) | |||
| def init_llm_factory(): | |||
| factory_infos = [{ | |||
| @@ -171,10 +209,10 @@ def init_llm_factory(): | |||
| def init_web_data(): | |||
| start_time = time.time() | |||
| if not UserService.get_all().count(): | |||
| init_superuser() | |||
| if not LLMService.get_all().count():init_llm_factory() | |||
| if not UserService.get_all().count(): | |||
| init_superuser() | |||
| print("init web data success:{}".format(time.time() - start_time)) | |||
| @@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from api.utils.log_utils import LoggerFactory, getLogger | |||
| from rag.nlp import search | |||
| from rag.utils import ELASTICSEARCH | |||
| # Server | |||
| API_VERSION = "v1" | |||
| RAG_FLOW_SERVICE_NAME = "ragflow" | |||
| SERVER_MODULE = "rag_flow_server.py" | |||
| @@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s | |||
| PRIVILEGE_COMMAND_WHITELIST = [] | |||
| CHECK_NODES_IDENTITY = False | |||
| retrievaler = search.Dealer(ELASTICSEARCH) | |||
| class CustomEnum(Enum): | |||
| @classmethod | |||
| def valid(cls, value): | |||
| @@ -230,7 +230,7 @@ class HuParser: | |||
| b["H_right"] = headers[ii]["x1"] | |||
| b["H"] = ii | |||
| ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) | |||
| ii = Recognizer.find_horizontally_tightest_fit(b, clmns) | |||
| if ii is not None: | |||
| b["C"] = ii | |||
| b["C_left"] = clmns[ii]["x0"] | |||
| @@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer): | |||
| super().__init__(self.labels, domain, | |||
| os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16): | |||
| def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): | |||
| def __is_garbage(b): | |||
| patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", | |||
| r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", | |||
| @@ -2,7 +2,6 @@ import copy | |||
| import numpy as np | |||
| import cv2 | |||
| import paddle | |||
| from shapely.geometry import Polygon | |||
| import pyclipper | |||
| @@ -215,7 +214,7 @@ class DBPostProcess(object): | |||
| def __call__(self, outs_dict, shape_list): | |||
| pred = outs_dict['maps'] | |||
| if isinstance(pred, paddle.Tensor): | |||
| if not isinstance(pred, np.ndarray): | |||
| pred = pred.numpy() | |||
| pred = pred[:, 0, :, :] | |||
| segmentation = pred > self.thresh | |||
| @@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode): | |||
| def __call__(self, preds, label=None, *args, **kwargs): | |||
| if isinstance(preds, tuple) or isinstance(preds, list): | |||
| preds = preds[-1] | |||
| if isinstance(preds, paddle.Tensor): | |||
| if not isinstance(preds, np.ndarray): | |||
| preds = preds.numpy() | |||
| preds_idx = preds.argmax(axis=2) | |||
| preds_prob = preds.max(axis=2) | |||
| @@ -259,6 +259,18 @@ class Recognizer(object): | |||
| return max_overlaped_i | |||
| @staticmethod | |||
| def find_horizontally_tightest_fit(box, boxes): | |||
| if not boxes: | |||
| return | |||
| min_dis, min_i = 1000000, None | |||
| for i,b in enumerate(boxes): | |||
| dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2) | |||
| if dis < min_dis: | |||
| min_i = i | |||
| min_dis = dis | |||
| return min_i | |||
| @staticmethod | |||
| def find_overlapped_with_threashold(box, boxes, thr=0.3): | |||
| if not boxes: | |||
| @@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| clmns = sorted([r for r in tb_cpns if re.match( | |||
| r"table column$", r["label"])], key=lambda x: x["x0"]) | |||
| clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5) | |||
| for b in boxes: | |||
| ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3) | |||
| if ii is not None: | |||
| @@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| b["H_right"] = headers[ii]["x1"] | |||
| b["H"] = ii | |||
| ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) | |||
| ii = Recognizer.find_horizontally_tightest_fit(b, clmns) | |||
| if ii is not None: | |||
| b["C"] = ii | |||
| b["C_left"] = clmns[ii]["x0"] | |||
| @@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| b["H_left"] = spans[ii]["x0"] | |||
| b["H_right"] = spans[ii]["x1"] | |||
| b["SP"] = ii | |||
| html = """ | |||
| <html> | |||
| <head> | |||
| @@ -14,7 +14,6 @@ import logging | |||
| import os | |||
| import re | |||
| from collections import Counter | |||
| from copy import deepcopy | |||
| import numpy as np | |||
| @@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer): | |||
| super().__init__(self.labels, "tsr", | |||
| os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| def __call__(self, images, thr=0.5): | |||
| def __call__(self, images, thr=0.2): | |||
| tbls = super().__call__(images, thr) | |||
| res = [] | |||
| # align left&right for rows, align top&bottom for columns | |||
| @@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer): | |||
| "row") > 0 or b["label"].find("header") > 0] | |||
| if not left: | |||
| continue | |||
| left = np.median(left) if len(left) > 4 else np.min(left) | |||
| right = np.median(right) if len(right) > 4 else np.max(right) | |||
| left = np.mean(left) if len(left) > 4 else np.min(left) | |||
| right = np.mean(right) if len(right) > 4 else np.max(right) | |||
| for b in lts: | |||
| if b["label"].find("row") > 0 or b["label"].find("header") > 0: | |||
| if b["x0"] > left: | |||
| @@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer): | |||
| i = 0 | |||
| while i < len(boxes): | |||
| if TableStructureRecognizer.is_caption(boxes[i]): | |||
| if is_english: cap + " " | |||
| cap += boxes[i]["text"] | |||
| boxes.pop(i) | |||
| i -= 1 | |||
| @@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer): | |||
| for i in range(clmno): | |||
| if not tbl[r][i]: | |||
| continue | |||
| txt = "".join([a["text"].strip() for a in tbl[r][i]]) | |||
| txt = " ".join([a["text"].strip() for a in tbl[r][i]]) | |||
| headers[r][i] = txt | |||
| hdrset.add(txt) | |||
| if all([not t for t in headers[r]]): | |||
| @@ -15,7 +15,7 @@ | |||
| # | |||
| from abc import ABC | |||
| from openai import OpenAI | |||
| import os | |||
| import openai | |||
| class Base(ABC): | |||
| @@ -33,11 +33,14 @@ class GptTurbo(Base): | |||
| def chat(self, system, history, gen_conf): | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=history, | |||
| **gen_conf) | |||
| return res.choices[0].message.content.strip(), res.usage.completion_tokens | |||
| try: | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=history, | |||
| **gen_conf) | |||
| return res.choices[0].message.content.strip(), res.usage.completion_tokens | |||
| except openai.APIError as e: | |||
| return "ERROR: "+str(e), 0 | |||
| from dashscope import Generation | |||
| @@ -58,7 +61,7 @@ class QWenChat(Base): | |||
| ) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'], response.usage.output_tokens | |||
| return response.message, 0 | |||
| return "ERROR: " + response.message, 0 | |||
| from zhipuai import ZhipuAI | |||
| @@ -77,4 +80,4 @@ class ZhipuChat(Base): | |||
| ) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'], response.usage.completion_tokens | |||
| return response.message, 0 | |||
| return "ERROR: " + response.message, 0 | |||
| @@ -1,7 +1,4 @@ | |||
| from . import search | |||
| from rag.utils import ELASTICSEARCH | |||
| retrievaler = search.Dealer(ELASTICSEARCH) | |||
| from nltk.stem import PorterStemmer | |||
| stemmer = PorterStemmer() | |||
| @@ -39,10 +36,12 @@ BULLET_PATTERN = [[ | |||
| ] | |||
| ] | |||
| def random_choices(arr, k): | |||
| k = min(len(arr), k) | |||
| return random.choices(arr, k=k) | |||
| def bullets_category(sections): | |||
| global BULLET_PATTERN | |||
| hits = [0] * len(BULLET_PATTERN) | |||
| @@ -1,7 +1,7 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import json | |||
| import re | |||
| from elasticsearch_dsl import Q, Search, A | |||
| from elasticsearch_dsl import Q, Search | |||
| from typing import List, Optional, Dict, Union | |||
| from dataclasses import dataclass | |||
| @@ -183,6 +183,7 @@ class Dealer: | |||
| def insert_citations(self, answer, chunks, chunk_v, | |||
| embd_mdl, tkweight=0.3, vtweight=0.7): | |||
| assert len(chunks) == len(chunk_v) | |||
| pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) | |||
| for i in range(1, len(pieces)): | |||
| if re.match(r"[a-z][.?;!][ \n]", pieces[i]): | |||
| @@ -216,7 +217,7 @@ class Dealer: | |||
| if mx < 0.55: | |||
| continue | |||
| cites[idx[i]] = list( | |||
| set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] | |||
| set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] | |||
| res = "" | |||
| for i, p in enumerate(pieces): | |||
| @@ -225,6 +226,7 @@ class Dealer: | |||
| continue | |||
| if i not in cites: | |||
| continue | |||
| assert int(cites[i]) < len(chunk_v) | |||
| res += "##%s$$" % "$".join(cites[i]) | |||
| return res | |||