Browse Source

refine admin initialization (#75)

tags/v0.1.0
KevinHuSh 1 year ago
parent
commit
4568a4b2cb
No account linked to committer's email address

+ 2
- 2
api/apps/chunk_app.py View File

from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, huqie, retrievaler
from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.settings import RetCode, retrievaler
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re

+ 1
- 3
api/apps/conversation_app.py View File

from api.db import LLMType from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger, stat_logger
from api.settings import access_logger, stat_logger, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder, rmSpace from rag.utils import num_tokens_from_string, encoder, rmSpace

+ 42
- 4
api/db/init_data.py View File

import time import time
import uuid import uuid
from api.db import LLMType
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
def init_superuser(): def init_superuser():
"creator": "system", "creator": "system",
"status": "1", "status": "1",
} }
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})
if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info) UserService.save(**user_info)
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"])
if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
def init_llm_factory(): def init_llm_factory():
factory_infos = [{ factory_infos = [{
def init_web_data(): def init_web_data():
start_time = time.time() start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory() if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time)) print("init web data success:{}".format(time.time() - start_time))

+ 5
- 1
api/settings.py View File

from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
from rag.nlp import search
from rag.utils import ELASTICSEARCH
# Server
API_VERSION = "v1" API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow" RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py" SERVER_MODULE = "rag_flow_server.py"
PRIVILEGE_COMMAND_WHITELIST = [] PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
class CustomEnum(Enum): class CustomEnum(Enum):
@classmethod @classmethod
def valid(cls, value): def valid(cls, value):

+ 1
- 1
deepdoc/parser/pdf_parser.py View File

b["H_right"] = headers[ii]["x1"] b["H_right"] = headers[ii]["x1"]
b["H"] = ii b["H"] = ii


ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None: if ii is not None:
b["C"] = ii b["C"] = ii
b["C_left"] = clmns[ii]["x0"] b["C_left"] = clmns[ii]["x0"]

+ 1
- 1
deepdoc/vision/layout_recognizer.py View File

super().__init__(self.labels, domain, super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",

+ 2
- 3
deepdoc/vision/postprocess.py View File

import numpy as np import numpy as np
import cv2 import cv2
import paddle
from shapely.geometry import Polygon from shapely.geometry import Polygon
import pyclipper import pyclipper
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
if not isinstance(pred, np.ndarray):
pred = pred.numpy() pred = pred.numpy()
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list): if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1] preds = preds[-1]
if isinstance(preds, paddle.Tensor):
if not isinstance(preds, np.ndarray):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)

+ 12
- 0
deepdoc/vision/recognizer.py View File

return max_overlaped_i return max_overlaped_i
@staticmethod
def find_horizontally_tightest_fit(box, boxes):
    """Return the index of the entry in ``boxes`` closest to ``box`` horizontally.

    The distance to each candidate is the smallest of three gaps: between
    the left edges (``x0``), between the right edges (``x1``), and between
    the horizontal centres (half the difference of ``x0 + x1``).

    Args:
        box: mapping with numeric ``"x0"`` and ``"x1"`` entries.
        boxes: sequence of mappings with the same two keys; may be empty
            or ``None``.

    Returns:
        The index into ``boxes`` of the candidate with the smallest
        distance (the first one on ties), or ``None`` when ``boxes`` is
        empty.
    """
    if not boxes:
        return None
    # Start from infinity instead of a fixed 1000000 cap, so a non-empty
    # candidate list always yields an index no matter how large the
    # coordinates are.
    min_dis, min_i = float("inf"), None
    for i, b in enumerate(boxes):
        dis = min(
            abs(box["x0"] - b["x0"]),
            abs(box["x1"] - b["x1"]),
            abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2,
        )
        # Strict comparison keeps the earliest candidate on ties.
        if dis < min_dis:
            min_i = i
            min_dis = dis
    return min_i
@staticmethod @staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3): def find_overlapped_with_threashold(box, boxes, thr=0.3):
if not boxes: if not boxes:

+ 3
- 1
deepdoc/vision/t_recognizer.py View File

clmns = sorted([r for r in tb_cpns if re.match( clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"]) r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5) clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
for b in boxes: for b in boxes:
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3) ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
if ii is not None: if ii is not None:
b["H_right"] = headers[ii]["x1"] b["H_right"] = headers[ii]["x1"]
b["H"] = ii b["H"] = ii
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None: if ii is not None:
b["C"] = ii b["C"] = ii
b["C_left"] = clmns[ii]["x0"] b["C_left"] = clmns[ii]["x0"]
b["H_left"] = spans[ii]["x0"] b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"] b["H_right"] = spans[ii]["x1"]
b["SP"] = ii b["SP"] = ii
html = """ html = """
<html> <html>
<head> <head>

+ 5
- 5
deepdoc/vision/table_structure_recognizer.py View File

import os import os
import re import re
from collections import Counter from collections import Counter
from copy import deepcopy
import numpy as np import numpy as np
super().__init__(self.labels, "tsr", super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, images, thr=0.5):
def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)
res = [] res = []
# align left&right for rows, align top&bottom for columns # align left&right for rows, align top&bottom for columns
"row") > 0 or b["label"].find("header") > 0] "row") > 0 or b["label"].find("header") > 0]
if not left: if not left:
continue continue
left = np.median(left) if len(left) > 4 else np.min(left)
right = np.median(right) if len(right) > 4 else np.max(right)
left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts: for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0: if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left: if b["x0"] > left:
i = 0 i = 0
while i < len(boxes): while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]): if TableStructureRecognizer.is_caption(boxes[i]):
if is_english: cap + " "
cap += boxes[i]["text"] cap += boxes[i]["text"]
boxes.pop(i) boxes.pop(i)
i -= 1 i -= 1
for i in range(clmno): for i in range(clmno):
if not tbl[r][i]: if not tbl[r][i]:
continue continue
txt = "".join([a["text"].strip() for a in tbl[r][i]])
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt headers[r][i] = txt
hdrset.add(txt) hdrset.add(txt)
if all([not t for t in headers[r]]): if all([not t for t in headers[r]]):

+ 11
- 8
rag/llm/chat_model.py View File

# #
from abc import ABC from abc import ABC
from openai import OpenAI from openai import OpenAI
import os
import openai




class Base(ABC): class Base(ABC):


def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system: history.insert(0, {"role": "system", "content": system})
res = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens
try:
res = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens
except openai.APIError as e:
return "ERROR: "+str(e), 0




from dashscope import Generation from dashscope import Generation
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0
return "ERROR: " + response.message, 0




from zhipuai import ZhipuAI from zhipuai import ZhipuAI
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0
return "ERROR: " + response.message, 0

+ 2
- 3
rag/nlp/__init__.py View File

from . import search
from rag.utils import ELASTICSEARCH
retrievaler = search.Dealer(ELASTICSEARCH)
from nltk.stem import PorterStemmer from nltk.stem import PorterStemmer
stemmer = PorterStemmer() stemmer = PorterStemmer()
] ]
] ]
def random_choices(arr, k): def random_choices(arr, k):
k = min(len(arr), k) k = min(len(arr), k)
return random.choices(arr, k=k) return random.choices(arr, k=k)
def bullets_category(sections): def bullets_category(sections):
global BULLET_PATTERN global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN) hits = [0] * len(BULLET_PATTERN)

+ 4
- 2
rag/nlp/search.py View File

# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import re import re
from elasticsearch_dsl import Q, Search, A
from elasticsearch_dsl import Q, Search
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass




def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7): embd_mdl, tkweight=0.3, vtweight=0.7):
assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)): for i in range(1, len(pieces)):
if re.match(r"[a-z][.?;!][ \n]", pieces[i]): if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
if mx < 0.55: if mx < 0.55:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]


res = "" res = ""
for i, p in enumerate(pieces): for i, p in enumerate(pieces):
continue continue
if i not in cites: if i not in cites:
continue continue
assert int(cites[i]) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) res += "##%s$$" % "$".join(cites[i])


return res return res

Loading…
Cancel
Save