 from elasticsearch_dsl import Q
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, huqie, retrievaler
+from rag.nlp import search, huqie
 from rag.utils import ELASTICSEARCH, rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, retrievaler
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger
+from api.settings import access_logger, stat_logger, retrievaler
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.app.resume import forbidden_select_fields4resume
-from rag.llm import ChatModel
-from rag.nlp import retrievaler
 from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder, rmSpace
 import time
 import uuid
-from api.db import LLMType
+from api.db import LLMType, UserTenantRole
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.services import UserService
-from api.db.services.llm_service import LLMFactoriesService, LLMService
+from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
+from api.db.services.user_service import TenantService, UserTenantService
+from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
 def init_superuser():
         "creator": "system",
         "status": "1",
     }
+    tenant = {
+        "id": user_info["id"],
+        "name": user_info["nickname"] + "'s Kingdom",
+        "llm_id": CHAT_MDL,
+        "embd_id": EMBEDDING_MDL,
+        "asr_id": ASR_MDL,
+        "parser_ids": PARSERS,
+        "img2txt_id": IMAGE2TEXT_MDL
+    }
+    usr_tenant = {
+        "tenant_id": user_info["id"],
+        "user_id": user_info["id"],
+        "invited_by": user_info["id"],
+        "role": UserTenantRole.OWNER
+    }
+    tenant_llm = []
+    for llm in LLMService.query(fid=LLM_FACTORY):
+        tenant_llm.append(
+            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name,
+             "model_type": llm.model_type, "api_key": API_KEY})
+    if not UserService.save(**user_info):
+        print("[ERROR] can't init admin.")
+        return
+    TenantService.save(**tenant)
+    UserTenantService.save(**usr_tenant)
+    TenantLLMService.insert_many(tenant_llm)
-    UserService.save(**user_info)
+    chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
+    msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
+    if msg.find("ERROR: ") == 0:
+        print("[ERROR] '{}' doesn't work. {}".format(tenant["llm_id"], msg))
+    embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
+    v, c = embd_mdl.encode(["Hello!"])
+    if c == 0:
+        print("[ERROR] '{}' doesn't work...".format(tenant["embd_id"]))
 def init_llm_factory():
     factory_infos = [{

 def init_web_data():
     start_time = time.time()
-    if not UserService.get_all().count():
-        init_superuser()
     if not LLMService.get_all().count(): init_llm_factory()
+    if not UserService.get_all().count():
+        init_superuser()
     print("init web data success:{}".format(time.time() - start_time))
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger
+from rag.nlp import search
+from rag.utils import ELASTICSEARCH

 # Server
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
 SERVER_MODULE = "rag_flow_server.py"
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False
+retrievaler = search.Dealer(ELASTICSEARCH)

 class CustomEnum(Enum):
     @classmethod
     def valid(cls, value):
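Moving the shared retrievaler (a search.Dealer bound to the Elasticsearch client) out of rag.nlp and into api.settings gives the API apps and the RAG code one import path for the same singleton, and presumably keeps rag/nlp/__init__.py from touching Elasticsearch at import time. A stand-in sketch of the pattern, since the real Dealer needs a live ES connection:

class Dealer:
    # stand-in for rag.nlp.search.Dealer
    def __init__(self, es):
        self.es = es

ELASTICSEARCH = object()              # placeholder for the shared ES wrapper
retrievaler = Dealer(ELASTICSEARCH)   # built once at module import

# elsewhere: `from api.settings import retrievaler` always yields this instance
assert retrievaler.es is ELASTICSEARCH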
| b["H_right"] = headers[ii]["x1"] | b["H_right"] = headers[ii]["x1"] | ||||
| b["H"] = ii | b["H"] = ii | ||||
| ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) | |||||
| ii = Recognizer.find_horizontally_tightest_fit(b, clmns) | |||||
| if ii is not None: | if ii is not None: | ||||
| b["C"] = ii | b["C"] = ii | ||||
| b["C_left"] = clmns[ii]["x0"] | b["C_left"] = clmns[ii]["x0"] |
         super().__init__(self.labels, domain,
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

-    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
+    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
         def __is_garbage(b):
             patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                     r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
 import numpy as np
 import cv2
-import paddle
 from shapely.geometry import Polygon
 import pyclipper

     def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
-        if isinstance(pred, paddle.Tensor):
+        if not isinstance(pred, np.ndarray):
             pred = pred.numpy()
         pred = pred[:, 0, :, :]
         segmentation = pred > self.thresh

     def __call__(self, preds, label=None, *args, **kwargs):
         if isinstance(preds, tuple) or isinstance(preds, list):
             preds = preds[-1]
-        if isinstance(preds, paddle.Tensor):
+        if not isinstance(preds, np.ndarray):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
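Inverting the check from isinstance(pred, paddle.Tensor) to not isinstance(pred, np.ndarray) is what lets the paddle import go: anything that is not already an ndarray is assumed to expose a .numpy() method, which holds for paddle tensors (and torch CPU tensors) alike. A small sketch of the duck-typed conversion, with FakeTensor as a stand-in for a framework tensor:

import numpy as np

class FakeTensor:
    # stand-in for paddle.Tensor / torch.Tensor
    def __init__(self, data):
        self._a = np.asarray(data)
    def numpy(self):
        return self._a

def to_ndarray(pred):
    if not isinstance(pred, np.ndarray):   # was: isinstance(pred, paddle.Tensor)
        pred = pred.numpy()
    return pred

print(to_ndarray(FakeTensor([[1.0, 2.0]])).shape)  # (1, 2) — converted
print(to_ndarray(np.zeros((3, 3))).shape)          # (3, 3) — passes through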
         return max_overlaped_i

+    @staticmethod
+    def find_horizontally_tightest_fit(box, boxes):
+        if not boxes:
+            return
+        min_dis, min_i = 1000000, None
+        for i, b in enumerate(boxes):
+            dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]),
+                      abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
+            if dis < min_dis:
+                min_i = i
+                min_dis = dis
+        return min_i
+
     @staticmethod
     def find_overlapped_with_threashold(box, boxes, thr=0.3):
         if not boxes:
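For column assignment, the new tightest-fit helper replaces the 0.3 overlap threshold: it scores each candidate by the smallest of three horizontal distances (left edges, right edges, centers) and returns the closest index, so a cell that overlaps two detected columns snaps to the better-aligned one. A toy check of that behavior, with the helper copied from the hunk above:

def find_horizontally_tightest_fit(box, boxes):
    if not boxes:
        return
    min_dis, min_i = 1000000, None
    for i, b in enumerate(boxes):
        dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]),
                  abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
        if dis < min_dis:
            min_i, min_dis = i, dis
    return min_i

cell = {"x0": 98, "x1": 152}
clmns = [{"x0": 60, "x1": 110},    # overlaps the cell, but loosely
         {"x0": 100, "x1": 150}]   # nearly coincident edges
print(find_horizontally_tightest_fit(cell, clmns))  # -> 1 (the tighter fit)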
     clmns = sorted([r for r in tb_cpns if re.match(
         r"table column$", r["label"])], key=lambda x: x["x0"])
     clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
     for b in boxes:
         ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         if ii is not None:
             b["H_right"] = headers[ii]["x1"]
             b["H"] = ii
-        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         if ii is not None:
             b["C"] = ii
             b["C_left"] = clmns[ii]["x0"]

             b["H_left"] = spans[ii]["x0"]
             b["H_right"] = spans[ii]["x1"]
             b["SP"] = ii

     html = """
     <html>
     <head>
 import os
 import re
 from collections import Counter
-from copy import deepcopy
 import numpy as np

         super().__init__(self.labels, "tsr",
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

-    def __call__(self, images, thr=0.5):
+    def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
         res = []
         # align left&right for rows, align top&bottom for columns
             "row") > 0 or b["label"].find("header") > 0]
         if not left:
             continue
-        left = np.median(left) if len(left) > 4 else np.min(left)
-        right = np.median(right) if len(right) > 4 else np.max(right)
+        left = np.mean(left) if len(left) > 4 else np.min(left)
+        right = np.mean(right) if len(right) > 4 else np.max(right)
         for b in lts:
             if b["label"].find("row") > 0 or b["label"].find("header") > 0:
                 if b["x0"] > left:

         i = 0
         while i < len(boxes):
             if TableStructureRecognizer.is_caption(boxes[i]):
-                if is_english: cap + " "
                 cap += boxes[i]["text"]
                 boxes.pop(i)
                 i -= 1

         for i in range(clmno):
             if not tbl[r][i]:
                 continue
-            txt = "".join([a["text"].strip() for a in tbl[r][i]])
+            txt = " ".join([a["text"].strip() for a in tbl[r][i]])
             headers[r][i] = txt
             hdrset.add(txt)
         if all([not t for t in headers[r]]):
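Two small behavior notes on this hunk: switching row-edge alignment from np.median to np.mean (taken when more than four row boxes are detected) lets every detected row pull on the shared edge, outliers included, where the median ignored them; and joining header cell fragments with " " instead of "" keeps multi-token header text readable. A toy illustration of the first change:

import numpy as np

lefts = [10.0, 10.2, 9.9, 10.1, 30.0]   # five detected rows, one with a bad x0
print(np.median(lefts))  # 10.1  -> outlier ignored entirely
print(np.mean(lefts))    # 14.04 -> outlier shifts the shared edge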
 #
 from abc import ABC
 from openai import OpenAI
+import os
+import openai

 class Base(ABC):
     def chat(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=history,
-            **gen_conf)
-        return res.choices[0].message.content.strip(), res.usage.completion_tokens
+        try:
+            res = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            return res.choices[0].message.content.strip(), res.usage.completion_tokens
+        except openai.APIError as e:
+            return "ERROR: " + str(e), 0

 from dashscope import Generation
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.output_tokens
-        return response.message, 0
+        return "ERROR: " + response.message, 0

 from zhipuai import ZhipuAI
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.completion_tokens
-        return response.message, 0
+        return "ERROR: " + response.message, 0
-from . import search
-from rag.utils import ELASTICSEARCH
-retrievaler = search.Dealer(ELASTICSEARCH)
 from nltk.stem import PorterStemmer
 stemmer = PorterStemmer()
 ]
 ]

 def random_choices(arr, k):
     k = min(len(arr), k)
     return random.choices(arr, k=k)

 def bullets_category(sections):
     global BULLET_PATTERN
     hits = [0] * len(BULLET_PATTERN)
 # -*- coding: utf-8 -*-
 import json
 import re
-from elasticsearch_dsl import Q, Search, A
+from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass

     def insert_citations(self, answer, chunks, chunk_v,
                          embd_mdl, tkweight=0.3, vtweight=0.7):
+        assert len(chunks) == len(chunk_v)
         pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
             if mx < 0.55:
                 continue
             cites[idx[i]] = list(
-                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
+                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]

         res = ""
         for i, p in enumerate(pieces):
                 continue
             if i not in cites:
                 continue
-            assert int(cites[i]) < len(chunk_v)
             res += "##%s$$" % "$".join(cites[i])
         return res
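Three tweaks to insert_citations: an upfront assert that chunks and vectors line up; the comprehension index renamed from i to ii so it no longer visually shadows the enclosing loop's i (Python 3 comprehensions scope their own variable, so this is a readability fix rather than a behavior change); and the stale per-piece assert dropped, since cites[i] holds a list of index strings, not a single int. The splitting regex is worth a quick demo, because its capturing group is what keeps the delimiters as standalone pieces that citation markers can be appended after (the [a-z] alternative deliberately captures the letter before the punctuation, which the follow-up loop appears to special-case):

import re

answer = "Paris is the capital. It sits on the Seine; population about 2M.\n"
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
print(pieces)
# ['Paris is the capita', 'l. ', 'It sits on the Sein', 'e; ',
#  'population about 2M.', '\n', '']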