| @@ -198,7 +198,7 @@ def chat(dialog, messages, **kwargs): | |||
| return {"answer": prompt_config["empty_response"], "retrieval": kbinfos} | |||
| kwargs["knowledge"] = "\n".join(knowledges) | |||
| gen_conf = dialog.llm_setting[dialog.llm_setting_type] | |||
| gen_conf = dialog.llm_setting | |||
| msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] | |||
| used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) | |||
| if "max_tokens" in gen_conf: | |||
| @@ -33,38 +33,17 @@ def set_dialog(): | |||
| name = req.get("name", "New Dialog") | |||
| description = req.get("description", "A helpful Dialog") | |||
| language = req.get("language", "Chinese") | |||
| llm_setting_type = req.get("llm_setting_type", "Precise") | |||
| top_n = req.get("top_n", 6) | |||
| similarity_threshold = req.get("similarity_threshold", 0.1) | |||
| vector_similarity_weight = req.get("vector_similarity_weight", 0.3) | |||
| llm_setting = req.get("llm_setting", { | |||
| "Creative": { | |||
| "temperature": 0.9, | |||
| "top_p": 0.9, | |||
| "frequency_penalty": 0.2, | |||
| "presence_penalty": 0.4, | |||
| "max_tokens": 512 | |||
| }, | |||
| "Precise": { | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, | |||
| "max_tokens": 215 | |||
| }, | |||
| "Evenly": { | |||
| "temperature": 0.5, | |||
| "top_p": 0.5, | |||
| "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, | |||
| "max_tokens": 215 | |||
| }, | |||
| "Custom": { | |||
| "temperature": 0.2, | |||
| "top_p": 0.3, | |||
| "frequency_penalty": 0.6, | |||
| "presence_penalty": 0.3, | |||
| "max_tokens": 215 | |||
| }, | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, | |||
| "max_tokens": 215 | |||
| }) | |||
| prompt_config = req.get("prompt_config", { | |||
| default_prompt = { | |||
| "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。 | |||
| 以下是知识库: | |||
| {knowledge} | |||
| @@ -74,30 +53,40 @@ def set_dialog(): | |||
| {"key": "knowledge", "optional": False} | |||
| ], | |||
| "empty_response": "Sorry! 知识库中未找到相关内容!" | |||
| }) | |||
| } | |||
| prompt_config = req.get("prompt_config", default_prompt) | |||
| if len(prompt_config["parameters"]) < 1: | |||
| return get_data_error_result(retmsg="'knowledge' should be in parameters") | |||
| if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"] | |||
| # if len(prompt_config["parameters"]) < 1: | |||
| # prompt_config["parameters"] = default_prompt["parameters"] | |||
| # for p in prompt_config["parameters"]: | |||
| # if p["key"] == "knowledge":break | |||
| # else: prompt_config["parameters"].append(default_prompt["parameters"][0]) | |||
| for p in prompt_config["parameters"]: | |||
| if prompt_config["system"].find("{%s}"%p["key"]) < 0: | |||
| if p["optional"]: continue | |||
| if prompt_config["system"].find("{%s}" % p["key"]) < 0: | |||
| return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"])) | |||
| try: | |||
| e, tenant = TenantService.get_by_id(current_user.id) | |||
| if not e:return get_data_error_result(retmsg="Tenant not found!") | |||
| if not e: return get_data_error_result(retmsg="Tenant not found!") | |||
| llm_id = req.get("llm_id", tenant.llm_id) | |||
| if not dialog_id: | |||
| if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!") | |||
| dia = { | |||
| "id": get_uuid(), | |||
| "tenant_id": current_user.id, | |||
| "name": name, | |||
| "kb_ids": req["kb_ids"], | |||
| "description": description, | |||
| "language": language, | |||
| "llm_id": llm_id, | |||
| "llm_setting_type": llm_setting_type, | |||
| "llm_setting": llm_setting, | |||
| "prompt_config": prompt_config | |||
| "prompt_config": prompt_config, | |||
| "top_n": top_n, | |||
| "similarity_threshold": similarity_threshold, | |||
| "vector_similarity_weight": vector_similarity_weight | |||
| } | |||
| if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") | |||
| e, dia = DialogService.get_by_id(dia["id"]) | |||
| @@ -122,7 +111,7 @@ def set_dialog(): | |||
| def get(): | |||
| dialog_id = request.args["dialog_id"] | |||
| try: | |||
| e,dia = DialogService.get_by_id(dialog_id) | |||
| e, dia = DialogService.get_by_id(dialog_id) | |||
| if not e: return get_data_error_result(retmsg="Dialog not found!") | |||
| dia = dia.to_dict() | |||
| dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) | |||
| @@ -130,20 +119,22 @@ def get(): | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| def get_kb_names(kb_ids): | |||
| ids, nms = [], [] | |||
| for kid in kb_ids: | |||
| e, kb = KnowledgebaseService.get_by_id(kid) | |||
| if not e or kb.status != StatusEnum.VALID.value:continue | |||
| if not e or kb.status != StatusEnum.VALID.value: continue | |||
| ids.append(kid) | |||
| nms.append(kb.name) | |||
| return ids, nms | |||
| @manager.route('/list', methods=['GET']) | |||
| @login_required | |||
| def list(): | |||
| try: | |||
| diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value) | |||
| diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time) | |||
| diags = [d.to_dict() for d in diags] | |||
| for d in diags: | |||
| d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) | |||
| @@ -154,12 +145,11 @@ def list(): | |||
| @manager.route('/rm', methods=['POST']) | |||
| @login_required | |||
| @validate_request("dialog_id") | |||
| @validate_request("dialog_ids") | |||
| def rm(): | |||
| req = request.json | |||
| try: | |||
| if not DialogService.update_by_id(req["dialog_id"], {"status": StatusEnum.INVALID.value}): | |||
| return get_data_error_result(retmsg="Dialog not found!") | |||
| DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| return server_error_response(e) | |||
| @@ -529,8 +529,6 @@ class Dialog(DataBaseModel): | |||
| icon = CharField(max_length=16, null=False, help_text="dialog icon") | |||
| language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") | |||
| llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | |||
| llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom", | |||
| default="Creative") | |||
| llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, "max_tokens": 215}) | |||
| prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") | |||
| @@ -1,4 +1,3 @@ | |||
| import copy | |||
| import random | |||
| from .pdf_parser import HuParser as PdfParser | |||
| @@ -10,7 +9,7 @@ import re | |||
| from nltk import word_tokenize | |||
| from rag.nlp import stemmer, huqie | |||
| from ..utils import num_tokens_from_string | |||
| from rag.utils import num_tokens_from_string | |||
| BULLET_PATTERN = [[ | |||
| r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", | |||
| @@ -1,7 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import random | |||
| from functools import partial | |||
| import fitz | |||
| import requests | |||
| @@ -15,6 +14,7 @@ from PIL import Image | |||
| import numpy as np | |||
| from api.db import ParserType | |||
| from deepdoc.visual import OCR, Recognizer | |||
| from rag.nlp import huqie | |||
| from collections import Counter | |||
| from copy import deepcopy | |||
| @@ -26,13 +26,32 @@ logging.getLogger("pdfminer").setLevel(logging.WARNING) | |||
| class HuParser: | |||
| def __init__(self): | |||
| from paddleocr import PaddleOCR | |||
| logging.getLogger("ppocr").setLevel(logging.ERROR) | |||
| self.ocr = PaddleOCR(use_angle_cls=False, lang="ch") | |||
| self.ocr = OCR() | |||
| if not hasattr(self, "model_speciess"): | |||
| self.model_speciess = ParserType.GENERAL.value | |||
| self.layouter = partial(self.__remote_call, self.model_speciess) | |||
| self.tbl_det = partial(self.__remote_call, "table_component") | |||
| self.layout_labels = [ | |||
| "_background_", | |||
| "Text", | |||
| "Title", | |||
| "Figure", | |||
| "Figure caption", | |||
| "Table", | |||
| "Table caption", | |||
| "Header", | |||
| "Footer", | |||
| "Reference", | |||
| "Equation", | |||
| ] | |||
| self.tsr_labels = [ | |||
| "table", | |||
| "table column", | |||
| "table row", | |||
| "table column header", | |||
| "table projected row header", | |||
| "table spanning cell", | |||
| ] | |||
| self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/") | |||
| self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/") | |||
| self.updown_cnt_mdl = xgb.Booster() | |||
| if torch.cuda.is_available(): | |||
| @@ -56,7 +75,7 @@ class HuParser: | |||
| token = os.environ.get("INFINIFLOW_TOKEN") | |||
| if not url or not token: | |||
| logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") | |||
| return [] | |||
| return [[] for _ in range(len(images))] | |||
| def convert_image_to_bytes(PILimage): | |||
| image = BytesIO() | |||
| @@ -382,7 +401,7 @@ class HuParser: | |||
| return layouts | |||
| def __table_paddle(self, images): | |||
| def __table_tsr(self, images): | |||
| tbls = self.tbl_det(images, thr=0.5) | |||
| res = [] | |||
| # align left&right for rows, align top&bottom for columns | |||
| @@ -452,7 +471,7 @@ class HuParser: | |||
| assert len(self.page_images) == len(tbcnt) - 1 | |||
| if not imgs: | |||
| return | |||
| recos = self.__table_paddle(imgs) | |||
| recos = self.__table_tsr(imgs) | |||
| tbcnt = np.cumsum(tbcnt) | |||
| for i in range(len(tbcnt) - 1): # for page | |||
| pg = [] | |||
| @@ -517,8 +536,8 @@ class HuParser: | |||
| b["H_right"] = spans[ii]["x1"] | |||
| b["SP"] = ii | |||
| def __ocr_paddle(self, pagenum, img, chars, ZM=3): | |||
| bxs = self.ocr.ocr(np.array(img), cls=True)[0] | |||
| def __ocr(self, pagenum, img, chars, ZM=3): | |||
| bxs = self.ocr(np.array(img)) | |||
| if not bxs: | |||
| self.boxes.append([]) | |||
| return | |||
| @@ -557,11 +576,12 @@ class HuParser: | |||
| self.boxes.append(bxs) | |||
| def _layouts_paddle(self, ZM): | |||
| def _layouts_rec(self, ZM): | |||
| assert len(self.page_images) == len(self.boxes) | |||
| # Tag layout type | |||
| boxes = [] | |||
| layouts = self.layouter(self.page_images) | |||
| #save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7) | |||
| assert len(self.page_images) == len(layouts) | |||
| for pn, lts in enumerate(layouts): | |||
| bxs = self.boxes[pn] | |||
| @@ -1741,7 +1761,7 @@ class HuParser: | |||
| # else: | |||
| # self.page_cum_height.append( | |||
| # np.max([c["bottom"] for c in chars])) | |||
| self.__ocr_paddle(i + 1, img, chars, zoomin) | |||
| self.__ocr(i + 1, img, chars, zoomin) | |||
| if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: | |||
| bxes = [b for bxs in self.boxes for b in bxs] | |||
| @@ -1754,7 +1774,7 @@ class HuParser: | |||
| def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): | |||
| self.__images__(fnm, zoomin) | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| self._table_transformer_job(zoomin) | |||
| self._text_merge() | |||
| self._concat_downward() | |||
| @@ -0,0 +1,2 @@ | |||
| from .ocr import OCR | |||
| from .recognizer import Recognizer | |||
| @@ -0,0 +1,561 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import time | |||
| import os | |||
| from huggingface_hub import snapshot_download | |||
| from .operators import * | |||
| import numpy as np | |||
| import onnxruntime as ort | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from .postprocess import build_post_process | |||
| from rag.settings import cron_logger | |||
| def transform(data, ops=None): | |||
| """ transform """ | |||
| if ops is None: | |||
| ops = [] | |||
| for op in ops: | |||
| data = op(data) | |||
| if data is None: | |||
| return None | |||
| return data | |||
| def create_operators(op_param_list, global_config=None): | |||
| """ | |||
| create operators based on the config | |||
| Args: | |||
| params(list): a dict list, used to create some operators | |||
| """ | |||
| assert isinstance( | |||
| op_param_list, list), ('operator config should be a list') | |||
| ops = [] | |||
| for operator in op_param_list: | |||
| assert isinstance(operator, | |||
| dict) and len(operator) == 1, "yaml format error" | |||
| op_name = list(operator)[0] | |||
| param = {} if operator[op_name] is None else operator[op_name] | |||
| if global_config is not None: | |||
| param.update(global_config) | |||
| op = eval(op_name)(**param) | |||
| ops.append(op) | |||
| return ops | |||
| def load_model(model_dir, nm): | |||
| model_file_path = os.path.join(model_dir, nm + ".onnx") | |||
| if not os.path.exists(model_file_path): | |||
| raise ValueError("not find model file path {}".format( | |||
| model_file_path)) | |||
| sess = ort.InferenceSession(model_file_path) | |||
| return sess, sess.get_inputs()[0] | |||
| class TextRecognizer(object): | |||
| def __init__(self, model_dir): | |||
| self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")] | |||
| self.rec_batch_num = 16 | |||
| postprocess_params = { | |||
| 'name': 'CTCLabelDecode', | |||
| "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"), | |||
| "use_space_char": True | |||
| } | |||
| self.postprocess_op = build_post_process(postprocess_params) | |||
| self.predictor, self.input_tensor = load_model(model_dir, 'rec') | |||
| def resize_norm_img(self, img, max_wh_ratio): | |||
| imgC, imgH, imgW = self.rec_image_shape | |||
| assert imgC == img.shape[2] | |||
| imgW = int((imgH * max_wh_ratio)) | |||
| w = self.input_tensor.shape[3:][0] | |||
| if isinstance(w, str): | |||
| pass | |||
| elif w is not None and w > 0: | |||
| imgW = w | |||
| h, w = img.shape[:2] | |||
| ratio = w / float(h) | |||
| if math.ceil(imgH * ratio) > imgW: | |||
| resized_w = imgW | |||
| else: | |||
| resized_w = int(math.ceil(imgH * ratio)) | |||
| resized_image = cv2.resize(img, (resized_w, imgH)) | |||
| resized_image = resized_image.astype('float32') | |||
| resized_image = resized_image.transpose((2, 0, 1)) / 255 | |||
| resized_image -= 0.5 | |||
| resized_image /= 0.5 | |||
| padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) | |||
| padding_im[:, :, 0:resized_w] = resized_image | |||
| return padding_im | |||
| def resize_norm_img_vl(self, img, image_shape): | |||
| imgC, imgH, imgW = image_shape | |||
| img = img[:, :, ::-1] # bgr2rgb | |||
| resized_image = cv2.resize( | |||
| img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) | |||
| resized_image = resized_image.astype('float32') | |||
| resized_image = resized_image.transpose((2, 0, 1)) / 255 | |||
| return resized_image | |||
| def resize_norm_img_srn(self, img, image_shape): | |||
| imgC, imgH, imgW = image_shape | |||
| img_black = np.zeros((imgH, imgW)) | |||
| im_hei = img.shape[0] | |||
| im_wid = img.shape[1] | |||
| if im_wid <= im_hei * 1: | |||
| img_new = cv2.resize(img, (imgH * 1, imgH)) | |||
| elif im_wid <= im_hei * 2: | |||
| img_new = cv2.resize(img, (imgH * 2, imgH)) | |||
| elif im_wid <= im_hei * 3: | |||
| img_new = cv2.resize(img, (imgH * 3, imgH)) | |||
| else: | |||
| img_new = cv2.resize(img, (imgW, imgH)) | |||
| img_np = np.asarray(img_new) | |||
| img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) | |||
| img_black[:, 0:img_np.shape[1]] = img_np | |||
| img_black = img_black[:, :, np.newaxis] | |||
| row, col, c = img_black.shape | |||
| c = 1 | |||
| return np.reshape(img_black, (c, row, col)).astype(np.float32) | |||
| def srn_other_inputs(self, image_shape, num_heads, max_text_length): | |||
| imgC, imgH, imgW = image_shape | |||
| feature_dim = int((imgH / 8) * (imgW / 8)) | |||
| encoder_word_pos = np.array(range(0, feature_dim)).reshape( | |||
| (feature_dim, 1)).astype('int64') | |||
| gsrm_word_pos = np.array(range(0, max_text_length)).reshape( | |||
| (max_text_length, 1)).astype('int64') | |||
| gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) | |||
| gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( | |||
| [-1, 1, max_text_length, max_text_length]) | |||
| gsrm_slf_attn_bias1 = np.tile( | |||
| gsrm_slf_attn_bias1, | |||
| [1, num_heads, 1, 1]).astype('float32') * [-1e9] | |||
| gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( | |||
| [-1, 1, max_text_length, max_text_length]) | |||
| gsrm_slf_attn_bias2 = np.tile( | |||
| gsrm_slf_attn_bias2, | |||
| [1, num_heads, 1, 1]).astype('float32') * [-1e9] | |||
| encoder_word_pos = encoder_word_pos[np.newaxis, :] | |||
| gsrm_word_pos = gsrm_word_pos[np.newaxis, :] | |||
| return [ | |||
| encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, | |||
| gsrm_slf_attn_bias2 | |||
| ] | |||
| def process_image_srn(self, img, image_shape, num_heads, max_text_length): | |||
| norm_img = self.resize_norm_img_srn(img, image_shape) | |||
| norm_img = norm_img[np.newaxis, :] | |||
| [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ | |||
| self.srn_other_inputs(image_shape, num_heads, max_text_length) | |||
| gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) | |||
| gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) | |||
| encoder_word_pos = encoder_word_pos.astype(np.int64) | |||
| gsrm_word_pos = gsrm_word_pos.astype(np.int64) | |||
| return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, | |||
| gsrm_slf_attn_bias2) | |||
| def resize_norm_img_sar(self, img, image_shape, | |||
| width_downsample_ratio=0.25): | |||
| imgC, imgH, imgW_min, imgW_max = image_shape | |||
| h = img.shape[0] | |||
| w = img.shape[1] | |||
| valid_ratio = 1.0 | |||
| # make sure new_width is an integral multiple of width_divisor. | |||
| width_divisor = int(1 / width_downsample_ratio) | |||
| # resize | |||
| ratio = w / float(h) | |||
| resize_w = math.ceil(imgH * ratio) | |||
| if resize_w % width_divisor != 0: | |||
| resize_w = round(resize_w / width_divisor) * width_divisor | |||
| if imgW_min is not None: | |||
| resize_w = max(imgW_min, resize_w) | |||
| if imgW_max is not None: | |||
| valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) | |||
| resize_w = min(imgW_max, resize_w) | |||
| resized_image = cv2.resize(img, (resize_w, imgH)) | |||
| resized_image = resized_image.astype('float32') | |||
| # norm | |||
| if image_shape[0] == 1: | |||
| resized_image = resized_image / 255 | |||
| resized_image = resized_image[np.newaxis, :] | |||
| else: | |||
| resized_image = resized_image.transpose((2, 0, 1)) / 255 | |||
| resized_image -= 0.5 | |||
| resized_image /= 0.5 | |||
| resize_shape = resized_image.shape | |||
| padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) | |||
| padding_im[:, :, 0:resize_w] = resized_image | |||
| pad_shape = padding_im.shape | |||
| return padding_im, resize_shape, pad_shape, valid_ratio | |||
| def resize_norm_img_spin(self, img): | |||
| img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) | |||
| # return padding_im | |||
| img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC) | |||
| img = np.array(img, np.float32) | |||
| img = np.expand_dims(img, -1) | |||
| img = img.transpose((2, 0, 1)) | |||
| mean = [127.5] | |||
| std = [127.5] | |||
| mean = np.array(mean, dtype=np.float32) | |||
| std = np.array(std, dtype=np.float32) | |||
| mean = np.float32(mean.reshape(1, -1)) | |||
| stdinv = 1 / np.float32(std.reshape(1, -1)) | |||
| img -= mean | |||
| img *= stdinv | |||
| return img | |||
| def resize_norm_img_svtr(self, img, image_shape): | |||
| imgC, imgH, imgW = image_shape | |||
| resized_image = cv2.resize( | |||
| img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) | |||
| resized_image = resized_image.astype('float32') | |||
| resized_image = resized_image.transpose((2, 0, 1)) / 255 | |||
| resized_image -= 0.5 | |||
| resized_image /= 0.5 | |||
| return resized_image | |||
| def resize_norm_img_abinet(self, img, image_shape): | |||
| imgC, imgH, imgW = image_shape | |||
| resized_image = cv2.resize( | |||
| img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) | |||
| resized_image = resized_image.astype('float32') | |||
| resized_image = resized_image / 255. | |||
| mean = np.array([0.485, 0.456, 0.406]) | |||
| std = np.array([0.229, 0.224, 0.225]) | |||
| resized_image = ( | |||
| resized_image - mean[None, None, ...]) / std[None, None, ...] | |||
| resized_image = resized_image.transpose((2, 0, 1)) | |||
| resized_image = resized_image.astype('float32') | |||
| return resized_image | |||
| def norm_img_can(self, img, image_shape): | |||
| img = cv2.cvtColor( | |||
| img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image | |||
| if self.rec_image_shape[0] == 1: | |||
| h, w = img.shape | |||
| _, imgH, imgW = self.rec_image_shape | |||
| if h < imgH or w < imgW: | |||
| padding_h = max(imgH - h, 0) | |||
| padding_w = max(imgW - w, 0) | |||
| img_padded = np.pad(img, ((0, padding_h), (0, padding_w)), | |||
| 'constant', | |||
| constant_values=(255)) | |||
| img = img_padded | |||
| img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w | |||
| img = img.astype('float32') | |||
| return img | |||
| def __call__(self, img_list): | |||
| img_num = len(img_list) | |||
| # Calculate the aspect ratio of all text bars | |||
| width_list = [] | |||
| for img in img_list: | |||
| width_list.append(img.shape[1] / float(img.shape[0])) | |||
| # Sorting can speed up the recognition process | |||
| indices = np.argsort(np.array(width_list)) | |||
| rec_res = [['', 0.0]] * img_num | |||
| batch_num = self.rec_batch_num | |||
| st = time.time() | |||
| for beg_img_no in range(0, img_num, batch_num): | |||
| end_img_no = min(img_num, beg_img_no + batch_num) | |||
| norm_img_batch = [] | |||
| imgC, imgH, imgW = self.rec_image_shape[:3] | |||
| max_wh_ratio = imgW / imgH | |||
| # max_wh_ratio = 0 | |||
| for ino in range(beg_img_no, end_img_no): | |||
| h, w = img_list[indices[ino]].shape[0:2] | |||
| wh_ratio = w * 1.0 / h | |||
| max_wh_ratio = max(max_wh_ratio, wh_ratio) | |||
| for ino in range(beg_img_no, end_img_no): | |||
| norm_img = self.resize_norm_img(img_list[indices[ino]], | |||
| max_wh_ratio) | |||
| norm_img = norm_img[np.newaxis, :] | |||
| norm_img_batch.append(norm_img) | |||
| norm_img_batch = np.concatenate(norm_img_batch) | |||
| norm_img_batch = norm_img_batch.copy() | |||
| input_dict = {} | |||
| input_dict[self.input_tensor.name] = norm_img_batch | |||
| outputs = self.predictor.run(None, input_dict) | |||
| preds = outputs[0] | |||
| rec_result = self.postprocess_op(preds) | |||
| for rno in range(len(rec_result)): | |||
| rec_res[indices[beg_img_no + rno]] = rec_result[rno] | |||
| return rec_res, time.time() - st | |||
| class TextDetector(object): | |||
| def __init__(self, model_dir): | |||
| pre_process_list = [{ | |||
| 'DetResizeForTest': { | |||
| 'limit_side_len': 960, | |||
| 'limit_type': "max", | |||
| } | |||
| }, { | |||
| 'NormalizeImage': { | |||
| 'std': [0.229, 0.224, 0.225], | |||
| 'mean': [0.485, 0.456, 0.406], | |||
| 'scale': '1./255.', | |||
| 'order': 'hwc' | |||
| } | |||
| }, { | |||
| 'ToCHWImage': None | |||
| }, { | |||
| 'KeepKeys': { | |||
| 'keep_keys': ['image', 'shape'] | |||
| } | |||
| }] | |||
| postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.6, "max_candidates": 1000, | |||
| "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"} | |||
| self.postprocess_op = build_post_process(postprocess_params) | |||
| self.predictor, self.input_tensor = load_model(model_dir, 'det') | |||
| img_h, img_w = self.input_tensor.shape[2:] | |||
| if isinstance(img_h, str) or isinstance(img_w, str): | |||
| pass | |||
| elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0: | |||
| pre_process_list[0] = { | |||
| 'DetResizeForTest': { | |||
| 'image_shape': [img_h, img_w] | |||
| } | |||
| } | |||
| self.preprocess_op = create_operators(pre_process_list) | |||
| def order_points_clockwise(self, pts): | |||
| rect = np.zeros((4, 2), dtype="float32") | |||
| s = pts.sum(axis=1) | |||
| rect[0] = pts[np.argmin(s)] | |||
| rect[2] = pts[np.argmax(s)] | |||
| tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) | |||
| diff = np.diff(np.array(tmp), axis=1) | |||
| rect[1] = tmp[np.argmin(diff)] | |||
| rect[3] = tmp[np.argmax(diff)] | |||
| return rect | |||
| def clip_det_res(self, points, img_height, img_width): | |||
| for pno in range(points.shape[0]): | |||
| points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) | |||
| points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) | |||
| return points | |||
| def filter_tag_det_res(self, dt_boxes, image_shape): | |||
| img_height, img_width = image_shape[0:2] | |||
| dt_boxes_new = [] | |||
| for box in dt_boxes: | |||
| if isinstance(box, list): | |||
| box = np.array(box) | |||
| box = self.order_points_clockwise(box) | |||
| box = self.clip_det_res(box, img_height, img_width) | |||
| rect_width = int(np.linalg.norm(box[0] - box[1])) | |||
| rect_height = int(np.linalg.norm(box[0] - box[3])) | |||
| if rect_width <= 3 or rect_height <= 3: | |||
| continue | |||
| dt_boxes_new.append(box) | |||
| dt_boxes = np.array(dt_boxes_new) | |||
| return dt_boxes | |||
| def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): | |||
| img_height, img_width = image_shape[0:2] | |||
| dt_boxes_new = [] | |||
| for box in dt_boxes: | |||
| if isinstance(box, list): | |||
| box = np.array(box) | |||
| box = self.clip_det_res(box, img_height, img_width) | |||
| dt_boxes_new.append(box) | |||
| dt_boxes = np.array(dt_boxes_new) | |||
| return dt_boxes | |||
| def __call__(self, img): | |||
| ori_im = img.copy() | |||
| data = {'image': img} | |||
| st = time.time() | |||
| data = transform(data, self.preprocess_op) | |||
| img, shape_list = data | |||
| if img is None: | |||
| return None, 0 | |||
| img = np.expand_dims(img, axis=0) | |||
| shape_list = np.expand_dims(shape_list, axis=0) | |||
| img = img.copy() | |||
| input_dict = {} | |||
| input_dict[self.input_tensor.name] = img | |||
| outputs = self.predictor.run(None, input_dict) | |||
| post_result = self.postprocess_op({"maps": outputs[0]}, shape_list) | |||
| dt_boxes = post_result[0]['points'] | |||
| dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) | |||
| return dt_boxes, time.time() - st | |||
| class OCR(object): | |||
| def __init__(self, model_dir=None): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| For Linux: | |||
| export HF_ENDPOINT=https://hf-mirror.com | |||
| For Windows: | |||
| Good luck | |||
| ^_- | |||
| """ | |||
| if not model_dir: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/ocr") | |||
| self.text_detector = TextDetector(model_dir) | |||
| self.text_recognizer = TextRecognizer(model_dir) | |||
| self.drop_score = 0.5 | |||
| self.crop_image_res_index = 0 | |||
| def get_rotate_crop_image(self, img, points): | |||
| ''' | |||
| img_height, img_width = img.shape[0:2] | |||
| left = int(np.min(points[:, 0])) | |||
| right = int(np.max(points[:, 0])) | |||
| top = int(np.min(points[:, 1])) | |||
| bottom = int(np.max(points[:, 1])) | |||
| img_crop = img[top:bottom, left:right, :].copy() | |||
| points[:, 0] = points[:, 0] - left | |||
| points[:, 1] = points[:, 1] - top | |||
| ''' | |||
| assert len(points) == 4, "shape of points must be 4*2" | |||
| img_crop_width = int( | |||
| max( | |||
| np.linalg.norm(points[0] - points[1]), | |||
| np.linalg.norm(points[2] - points[3]))) | |||
| img_crop_height = int( | |||
| max( | |||
| np.linalg.norm(points[0] - points[3]), | |||
| np.linalg.norm(points[1] - points[2]))) | |||
| pts_std = np.float32([[0, 0], [img_crop_width, 0], | |||
| [img_crop_width, img_crop_height], | |||
| [0, img_crop_height]]) | |||
| M = cv2.getPerspectiveTransform(points, pts_std) | |||
| dst_img = cv2.warpPerspective( | |||
| img, | |||
| M, (img_crop_width, img_crop_height), | |||
| borderMode=cv2.BORDER_REPLICATE, | |||
| flags=cv2.INTER_CUBIC) | |||
| dst_img_height, dst_img_width = dst_img.shape[0:2] | |||
| if dst_img_height * 1.0 / dst_img_width >= 1.5: | |||
| dst_img = np.rot90(dst_img) | |||
| return dst_img | |||
| def sorted_boxes(self, dt_boxes): | |||
| """ | |||
| Sort text boxes in order from top to bottom, left to right | |||
| args: | |||
| dt_boxes(array):detected text boxes with shape [4, 2] | |||
| return: | |||
| sorted boxes(array) with shape [4, 2] | |||
| """ | |||
| num_boxes = dt_boxes.shape[0] | |||
| sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) | |||
| _boxes = list(sorted_boxes) | |||
| for i in range(num_boxes - 1): | |||
| for j in range(i, -1, -1): | |||
| if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \ | |||
| (_boxes[j + 1][0][0] < _boxes[j][0][0]): | |||
| tmp = _boxes[j] | |||
| _boxes[j] = _boxes[j + 1] | |||
| _boxes[j + 1] = tmp | |||
| else: | |||
| break | |||
| return _boxes | |||
| def __call__(self, img, cls=True): | |||
| time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} | |||
| if img is None: | |||
| return None, None, time_dict | |||
| start = time.time() | |||
| ori_im = img.copy() | |||
| dt_boxes, elapse = self.text_detector(img) | |||
| time_dict['det'] = elapse | |||
| if dt_boxes is None: | |||
| end = time.time() | |||
| time_dict['all'] = end - start | |||
| return None, None, time_dict | |||
| else: | |||
| cron_logger.debug("dt_boxes num : {}, elapsed : {}".format( | |||
| len(dt_boxes), elapse)) | |||
| img_crop_list = [] | |||
| dt_boxes = self.sorted_boxes(dt_boxes) | |||
| for bno in range(len(dt_boxes)): | |||
| tmp_box = copy.deepcopy(dt_boxes[bno]) | |||
| img_crop = self.get_rotate_crop_image(ori_im, tmp_box) | |||
| img_crop_list.append(img_crop) | |||
| rec_res, elapse = self.text_recognizer(img_crop_list) | |||
| time_dict['rec'] = elapse | |||
| cron_logger.debug("rec_res num : {}, elapsed : {}".format( | |||
| len(rec_res), elapse)) | |||
| filter_boxes, filter_rec_res = [], [] | |||
| for box, rec_result in zip(dt_boxes, rec_res): | |||
| text, score = rec_result | |||
| if score >= self.drop_score: | |||
| filter_boxes.append(box) | |||
| filter_rec_res.append(rec_result) | |||
| end = time.time() | |||
| time_dict['all'] = end - start | |||
| #for bno in range(len(img_crop_list)): | |||
| # print(f"{bno}, {rec_res[bno]}") | |||
| return list(zip([a.tolist() for a in filter_boxes], filter_rec_res)) | |||
| @@ -0,0 +1,710 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import sys | |||
| import six | |||
| import cv2 | |||
| import numpy as np | |||
| import math | |||
| from PIL import Image | |||
| class DecodeImage(object): | |||
| """ decode image """ | |||
| def __init__(self, | |||
| img_mode='RGB', | |||
| channel_first=False, | |||
| ignore_orientation=False, | |||
| **kwargs): | |||
| self.img_mode = img_mode | |||
| self.channel_first = channel_first | |||
| self.ignore_orientation = ignore_orientation | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| if six.PY2: | |||
| assert isinstance(img, str) and len( | |||
| img) > 0, "invalid input 'img' in DecodeImage" | |||
| else: | |||
| assert isinstance(img, bytes) and len( | |||
| img) > 0, "invalid input 'img' in DecodeImage" | |||
| img = np.frombuffer(img, dtype='uint8') | |||
| if self.ignore_orientation: | |||
| img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | | |||
| cv2.IMREAD_COLOR) | |||
| else: | |||
| img = cv2.imdecode(img, 1) | |||
| if img is None: | |||
| return None | |||
| if self.img_mode == 'GRAY': | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| elif self.img_mode == 'RGB': | |||
| assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( | |||
| img.shape) | |||
| img = img[:, :, ::-1] | |||
| if self.channel_first: | |||
| img = img.transpose((2, 0, 1)) | |||
| data['image'] = img | |||
| return data | |||
| class StandardizeImage(object): | |||
| """normalize image | |||
| Args: | |||
| mean (list): im - mean | |||
| std (list): im / std | |||
| is_scale (bool): whether need im / 255 | |||
| norm_type (str): type in ['mean_std', 'none'] | |||
| """ | |||
| def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.is_scale = is_scale | |||
| self.norm_type = norm_type | |||
| def __call__(self, im, im_info): | |||
| """ | |||
| Args: | |||
| im (np.ndarray): image (np.ndarray) | |||
| im_info (dict): info of image | |||
| Returns: | |||
| im (np.ndarray): processed image (np.ndarray) | |||
| im_info (dict): info of processed image | |||
| """ | |||
| im = im.astype(np.float32, copy=False) | |||
| if self.is_scale: | |||
| scale = 1.0 / 255.0 | |||
| im *= scale | |||
| if self.norm_type == 'mean_std': | |||
| mean = np.array(self.mean)[np.newaxis, np.newaxis, :] | |||
| std = np.array(self.std)[np.newaxis, np.newaxis, :] | |||
| im -= mean | |||
| im /= std | |||
| return im, im_info | |||
| class NormalizeImage(object): | |||
| """ normalize image such as substract mean, divide std | |||
| """ | |||
| def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): | |||
| if isinstance(scale, str): | |||
| scale = eval(scale) | |||
| self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) | |||
| mean = mean if mean is not None else [0.485, 0.456, 0.406] | |||
| std = std if std is not None else [0.229, 0.224, 0.225] | |||
| shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) | |||
| self.mean = np.array(mean).reshape(shape).astype('float32') | |||
| self.std = np.array(std).reshape(shape).astype('float32') | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| from PIL import Image | |||
| if isinstance(img, Image.Image): | |||
| img = np.array(img) | |||
| assert isinstance(img, | |||
| np.ndarray), "invalid input 'img' in NormalizeImage" | |||
| data['image'] = ( | |||
| img.astype('float32') * self.scale - self.mean) / self.std | |||
| return data | |||
| class ToCHWImage(object): | |||
| """ convert hwc image to chw image | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| pass | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| from PIL import Image | |||
| if isinstance(img, Image.Image): | |||
| img = np.array(img) | |||
| data['image'] = img.transpose((2, 0, 1)) | |||
| return data | |||
| class Fasttext(object): | |||
| def __init__(self, path="None", **kwargs): | |||
| import fasttext | |||
| self.fast_model = fasttext.load_model(path) | |||
| def __call__(self, data): | |||
| label = data['label'] | |||
| fast_label = self.fast_model[label] | |||
| data['fast_label'] = fast_label | |||
| return data | |||
| class KeepKeys(object): | |||
| def __init__(self, keep_keys, **kwargs): | |||
| self.keep_keys = keep_keys | |||
| def __call__(self, data): | |||
| data_list = [] | |||
| for key in self.keep_keys: | |||
| data_list.append(data[key]) | |||
| return data_list | |||
| class Pad(object): | |||
| def __init__(self, size=None, size_div=32, **kwargs): | |||
| if size is not None and not isinstance(size, (int, list, tuple)): | |||
| raise TypeError("Type of target_size is invalid. Now is {}".format( | |||
| type(size))) | |||
| if isinstance(size, int): | |||
| size = [size, size] | |||
| self.size = size | |||
| self.size_div = size_div | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| img_h, img_w = img.shape[0], img.shape[1] | |||
| if self.size: | |||
| resize_h2, resize_w2 = self.size | |||
| assert ( | |||
| img_h < resize_h2 and img_w < resize_w2 | |||
| ), '(h, w) of target size should be greater than (img_h, img_w)' | |||
| else: | |||
| resize_h2 = max( | |||
| int(math.ceil(img.shape[0] / self.size_div) * self.size_div), | |||
| self.size_div) | |||
| resize_w2 = max( | |||
| int(math.ceil(img.shape[1] / self.size_div) * self.size_div), | |||
| self.size_div) | |||
| img = cv2.copyMakeBorder( | |||
| img, | |||
| 0, | |||
| resize_h2 - img_h, | |||
| 0, | |||
| resize_w2 - img_w, | |||
| cv2.BORDER_CONSTANT, | |||
| value=0) | |||
| data['image'] = img | |||
| return data | |||
| class LinearResize(object): | |||
| """resize image by target_size and max_size | |||
| Args: | |||
| target_size (int): the target size of image | |||
| keep_ratio (bool): whether keep_ratio or not, default true | |||
| interp (int): method of resize | |||
| """ | |||
| def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): | |||
| if isinstance(target_size, int): | |||
| target_size = [target_size, target_size] | |||
| self.target_size = target_size | |||
| self.keep_ratio = keep_ratio | |||
| self.interp = interp | |||
| def __call__(self, im, im_info): | |||
| """ | |||
| Args: | |||
| im (np.ndarray): image (np.ndarray) | |||
| im_info (dict): info of image | |||
| Returns: | |||
| im (np.ndarray): processed image (np.ndarray) | |||
| im_info (dict): info of processed image | |||
| """ | |||
| assert len(self.target_size) == 2 | |||
| assert self.target_size[0] > 0 and self.target_size[1] > 0 | |||
| im_channel = im.shape[2] | |||
| im_scale_y, im_scale_x = self.generate_scale(im) | |||
| im = cv2.resize( | |||
| im, | |||
| None, | |||
| None, | |||
| fx=im_scale_x, | |||
| fy=im_scale_y, | |||
| interpolation=self.interp) | |||
| im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') | |||
| im_info['scale_factor'] = np.array( | |||
| [im_scale_y, im_scale_x]).astype('float32') | |||
| return im, im_info | |||
| def generate_scale(self, im): | |||
| """ | |||
| Args: | |||
| im (np.ndarray): image (np.ndarray) | |||
| Returns: | |||
| im_scale_x: the resize ratio of X | |||
| im_scale_y: the resize ratio of Y | |||
| """ | |||
| origin_shape = im.shape[:2] | |||
| im_c = im.shape[2] | |||
| if self.keep_ratio: | |||
| im_size_min = np.min(origin_shape) | |||
| im_size_max = np.max(origin_shape) | |||
| target_size_min = np.min(self.target_size) | |||
| target_size_max = np.max(self.target_size) | |||
| im_scale = float(target_size_min) / float(im_size_min) | |||
| if np.round(im_scale * im_size_max) > target_size_max: | |||
| im_scale = float(target_size_max) / float(im_size_max) | |||
| im_scale_x = im_scale | |||
| im_scale_y = im_scale | |||
| else: | |||
| resize_h, resize_w = self.target_size | |||
| im_scale_y = resize_h / float(origin_shape[0]) | |||
| im_scale_x = resize_w / float(origin_shape[1]) | |||
| return im_scale_y, im_scale_x | |||
| class Resize(object): | |||
| def __init__(self, size=(640, 640), **kwargs): | |||
| self.size = size | |||
| def resize_image(self, img): | |||
| resize_h, resize_w = self.size | |||
| ori_h, ori_w = img.shape[:2] # (h, w, c) | |||
| ratio_h = float(resize_h) / ori_h | |||
| ratio_w = float(resize_w) / ori_w | |||
| img = cv2.resize(img, (int(resize_w), int(resize_h))) | |||
| return img, [ratio_h, ratio_w] | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| if 'polys' in data: | |||
| text_polys = data['polys'] | |||
| img_resize, [ratio_h, ratio_w] = self.resize_image(img) | |||
| if 'polys' in data: | |||
| new_boxes = [] | |||
| for box in text_polys: | |||
| new_box = [] | |||
| for cord in box: | |||
| new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) | |||
| new_boxes.append(new_box) | |||
| data['polys'] = np.array(new_boxes, dtype=np.float32) | |||
| data['image'] = img_resize | |||
| return data | |||
| class DetResizeForTest(object): | |||
| def __init__(self, **kwargs): | |||
| super(DetResizeForTest, self).__init__() | |||
| self.resize_type = 0 | |||
| self.keep_ratio = False | |||
| if 'image_shape' in kwargs: | |||
| self.image_shape = kwargs['image_shape'] | |||
| self.resize_type = 1 | |||
| if 'keep_ratio' in kwargs: | |||
| self.keep_ratio = kwargs['keep_ratio'] | |||
| elif 'limit_side_len' in kwargs: | |||
| self.limit_side_len = kwargs['limit_side_len'] | |||
| self.limit_type = kwargs.get('limit_type', 'min') | |||
| elif 'resize_long' in kwargs: | |||
| self.resize_type = 2 | |||
| self.resize_long = kwargs.get('resize_long', 960) | |||
| else: | |||
| self.limit_side_len = 736 | |||
| self.limit_type = 'min' | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| src_h, src_w, _ = img.shape | |||
| if sum([src_h, src_w]) < 64: | |||
| img = self.image_padding(img) | |||
| if self.resize_type == 0: | |||
| # img, shape = self.resize_image_type0(img) | |||
| img, [ratio_h, ratio_w] = self.resize_image_type0(img) | |||
| elif self.resize_type == 2: | |||
| img, [ratio_h, ratio_w] = self.resize_image_type2(img) | |||
| else: | |||
| # img, shape = self.resize_image_type1(img) | |||
| img, [ratio_h, ratio_w] = self.resize_image_type1(img) | |||
| data['image'] = img | |||
| data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) | |||
| return data | |||
| def image_padding(self, im, value=0): | |||
| h, w, c = im.shape | |||
| im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value | |||
| im_pad[:h, :w, :] = im | |||
| return im_pad | |||
| def resize_image_type1(self, img): | |||
| resize_h, resize_w = self.image_shape | |||
| ori_h, ori_w = img.shape[:2] # (h, w, c) | |||
| if self.keep_ratio is True: | |||
| resize_w = ori_w * resize_h / ori_h | |||
| N = math.ceil(resize_w / 32) | |||
| resize_w = N * 32 | |||
| ratio_h = float(resize_h) / ori_h | |||
| ratio_w = float(resize_w) / ori_w | |||
| img = cv2.resize(img, (int(resize_w), int(resize_h))) | |||
| # return img, np.array([ori_h, ori_w]) | |||
| return img, [ratio_h, ratio_w] | |||
| def resize_image_type0(self, img): | |||
| """ | |||
| resize image to a size multiple of 32 which is required by the network | |||
| args: | |||
| img(array): array with shape [h, w, c] | |||
| return(tuple): | |||
| img, (ratio_h, ratio_w) | |||
| """ | |||
| limit_side_len = self.limit_side_len | |||
| h, w, c = img.shape | |||
| # limit the max side | |||
| if self.limit_type == 'max': | |||
| if max(h, w) > limit_side_len: | |||
| if h > w: | |||
| ratio = float(limit_side_len) / h | |||
| else: | |||
| ratio = float(limit_side_len) / w | |||
| else: | |||
| ratio = 1. | |||
| elif self.limit_type == 'min': | |||
| if min(h, w) < limit_side_len: | |||
| if h < w: | |||
| ratio = float(limit_side_len) / h | |||
| else: | |||
| ratio = float(limit_side_len) / w | |||
| else: | |||
| ratio = 1. | |||
| elif self.limit_type == 'resize_long': | |||
| ratio = float(limit_side_len) / max(h, w) | |||
| else: | |||
| raise Exception('not support limit type, image ') | |||
| resize_h = int(h * ratio) | |||
| resize_w = int(w * ratio) | |||
| resize_h = max(int(round(resize_h / 32) * 32), 32) | |||
| resize_w = max(int(round(resize_w / 32) * 32), 32) | |||
| try: | |||
| if int(resize_w) <= 0 or int(resize_h) <= 0: | |||
| return None, (None, None) | |||
| img = cv2.resize(img, (int(resize_w), int(resize_h))) | |||
| except BaseException: | |||
| print(img.shape, resize_w, resize_h) | |||
| sys.exit(0) | |||
| ratio_h = resize_h / float(h) | |||
| ratio_w = resize_w / float(w) | |||
| return img, [ratio_h, ratio_w] | |||
| def resize_image_type2(self, img): | |||
| h, w, _ = img.shape | |||
| resize_w = w | |||
| resize_h = h | |||
| if resize_h > resize_w: | |||
| ratio = float(self.resize_long) / resize_h | |||
| else: | |||
| ratio = float(self.resize_long) / resize_w | |||
| resize_h = int(resize_h * ratio) | |||
| resize_w = int(resize_w * ratio) | |||
| max_stride = 128 | |||
| resize_h = (resize_h + max_stride - 1) // max_stride * max_stride | |||
| resize_w = (resize_w + max_stride - 1) // max_stride * max_stride | |||
| img = cv2.resize(img, (int(resize_w), int(resize_h))) | |||
| ratio_h = resize_h / float(h) | |||
| ratio_w = resize_w / float(w) | |||
| return img, [ratio_h, ratio_w] | |||
| class E2EResizeForTest(object): | |||
| def __init__(self, **kwargs): | |||
| super(E2EResizeForTest, self).__init__() | |||
| self.max_side_len = kwargs['max_side_len'] | |||
| self.valid_set = kwargs['valid_set'] | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| src_h, src_w, _ = img.shape | |||
| if self.valid_set == 'totaltext': | |||
| im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( | |||
| img, max_side_len=self.max_side_len) | |||
| else: | |||
| im_resized, (ratio_h, ratio_w) = self.resize_image( | |||
| img, max_side_len=self.max_side_len) | |||
| data['image'] = im_resized | |||
| data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) | |||
| return data | |||
| def resize_image_for_totaltext(self, im, max_side_len=512): | |||
| h, w, _ = im.shape | |||
| resize_w = w | |||
| resize_h = h | |||
| ratio = 1.25 | |||
| if h * ratio > max_side_len: | |||
| ratio = float(max_side_len) / resize_h | |||
| resize_h = int(resize_h * ratio) | |||
| resize_w = int(resize_w * ratio) | |||
| max_stride = 128 | |||
| resize_h = (resize_h + max_stride - 1) // max_stride * max_stride | |||
| resize_w = (resize_w + max_stride - 1) // max_stride * max_stride | |||
| im = cv2.resize(im, (int(resize_w), int(resize_h))) | |||
| ratio_h = resize_h / float(h) | |||
| ratio_w = resize_w / float(w) | |||
| return im, (ratio_h, ratio_w) | |||
| def resize_image(self, im, max_side_len=512): | |||
| """ | |||
| resize image to a size multiple of max_stride which is required by the network | |||
| :param im: the resized image | |||
| :param max_side_len: limit of max image size to avoid out of memory in gpu | |||
| :return: the resized image and the resize ratio | |||
| """ | |||
| h, w, _ = im.shape | |||
| resize_w = w | |||
| resize_h = h | |||
| # Fix the longer side | |||
| if resize_h > resize_w: | |||
| ratio = float(max_side_len) / resize_h | |||
| else: | |||
| ratio = float(max_side_len) / resize_w | |||
| resize_h = int(resize_h * ratio) | |||
| resize_w = int(resize_w * ratio) | |||
| max_stride = 128 | |||
| resize_h = (resize_h + max_stride - 1) // max_stride * max_stride | |||
| resize_w = (resize_w + max_stride - 1) // max_stride * max_stride | |||
| im = cv2.resize(im, (int(resize_w), int(resize_h))) | |||
| ratio_h = resize_h / float(h) | |||
| ratio_w = resize_w / float(w) | |||
| return im, (ratio_h, ratio_w) | |||
| class KieResize(object): | |||
| def __init__(self, **kwargs): | |||
| super(KieResize, self).__init__() | |||
| self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[ | |||
| 'img_scale'][1] | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| points = data['points'] | |||
| src_h, src_w, _ = img.shape | |||
| im_resized, scale_factor, [ratio_h, ratio_w | |||
| ], [new_h, new_w] = self.resize_image(img) | |||
| resize_points = self.resize_boxes(img, points, scale_factor) | |||
| data['ori_image'] = img | |||
| data['ori_boxes'] = points | |||
| data['points'] = resize_points | |||
| data['image'] = im_resized | |||
| data['shape'] = np.array([new_h, new_w]) | |||
| return data | |||
| def resize_image(self, img): | |||
| norm_img = np.zeros([1024, 1024, 3], dtype='float32') | |||
| scale = [512, 1024] | |||
| h, w = img.shape[:2] | |||
| max_long_edge = max(scale) | |||
| max_short_edge = min(scale) | |||
| scale_factor = min(max_long_edge / max(h, w), | |||
| max_short_edge / min(h, w)) | |||
| resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float( | |||
| scale_factor) + 0.5) | |||
| max_stride = 32 | |||
| resize_h = (resize_h + max_stride - 1) // max_stride * max_stride | |||
| resize_w = (resize_w + max_stride - 1) // max_stride * max_stride | |||
| im = cv2.resize(img, (resize_w, resize_h)) | |||
| new_h, new_w = im.shape[:2] | |||
| w_scale = new_w / w | |||
| h_scale = new_h / h | |||
| scale_factor = np.array( | |||
| [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) | |||
| norm_img[:new_h, :new_w, :] = im | |||
| return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] | |||
| def resize_boxes(self, im, points, scale_factor): | |||
| points = points * scale_factor | |||
| img_shape = im.shape[:2] | |||
| points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) | |||
| points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) | |||
| return points | |||
| class SRResize(object): | |||
| def __init__(self, | |||
| imgH=32, | |||
| imgW=128, | |||
| down_sample_scale=4, | |||
| keep_ratio=False, | |||
| min_ratio=1, | |||
| mask=False, | |||
| infer_mode=False, | |||
| **kwargs): | |||
| self.imgH = imgH | |||
| self.imgW = imgW | |||
| self.keep_ratio = keep_ratio | |||
| self.min_ratio = min_ratio | |||
| self.down_sample_scale = down_sample_scale | |||
| self.mask = mask | |||
| self.infer_mode = infer_mode | |||
| def __call__(self, data): | |||
| imgH = self.imgH | |||
| imgW = self.imgW | |||
| images_lr = data["image_lr"] | |||
| transform2 = ResizeNormalize( | |||
| (imgW // self.down_sample_scale, imgH // self.down_sample_scale)) | |||
| images_lr = transform2(images_lr) | |||
| data["img_lr"] = images_lr | |||
| if self.infer_mode: | |||
| return data | |||
| images_HR = data["image_hr"] | |||
| label_strs = data["label"] | |||
| transform = ResizeNormalize((imgW, imgH)) | |||
| images_HR = transform(images_HR) | |||
| data["img_hr"] = images_HR | |||
| return data | |||
| class ResizeNormalize(object): | |||
| def __init__(self, size, interpolation=Image.BICUBIC): | |||
| self.size = size | |||
| self.interpolation = interpolation | |||
| def __call__(self, img): | |||
| img = img.resize(self.size, self.interpolation) | |||
| img_numpy = np.array(img).astype("float32") | |||
| img_numpy = img_numpy.transpose((2, 0, 1)) / 255 | |||
| return img_numpy | |||
| class GrayImageChannelFormat(object): | |||
| """ | |||
| format gray scale image's channel: (3,h,w) -> (1,h,w) | |||
| Args: | |||
| inverse: inverse gray image | |||
| """ | |||
| def __init__(self, inverse=False, **kwargs): | |||
| self.inverse = inverse | |||
| def __call__(self, data): | |||
| img = data['image'] | |||
| img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) | |||
| img_expanded = np.expand_dims(img_single_channel, 0) | |||
| if self.inverse: | |||
| data['image'] = np.abs(img_expanded - 1) | |||
| else: | |||
| data['image'] = img_expanded | |||
| data['src_image'] = img | |||
| return data | |||
| class Permute(object): | |||
| """permute image | |||
| Args: | |||
| to_bgr (bool): whether convert RGB to BGR | |||
| channel_first (bool): whether convert HWC to CHW | |||
| """ | |||
| def __init__(self, ): | |||
| super(Permute, self).__init__() | |||
| def __call__(self, im, im_info): | |||
| """ | |||
| Args: | |||
| im (np.ndarray): image (np.ndarray) | |||
| im_info (dict): info of image | |||
| Returns: | |||
| im (np.ndarray): processed image (np.ndarray) | |||
| im_info (dict): info of processed image | |||
| """ | |||
| im = im.transpose((2, 0, 1)).copy() | |||
| return im, im_info | |||
| class PadStride(object): | |||
| """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config | |||
| Args: | |||
| stride (bool): model with FPN need image shape % stride == 0 | |||
| """ | |||
| def __init__(self, stride=0): | |||
| self.coarsest_stride = stride | |||
| def __call__(self, im, im_info): | |||
| """ | |||
| Args: | |||
| im (np.ndarray): image (np.ndarray) | |||
| im_info (dict): info of image | |||
| Returns: | |||
| im (np.ndarray): processed image (np.ndarray) | |||
| im_info (dict): info of processed image | |||
| """ | |||
| coarsest_stride = self.coarsest_stride | |||
| if coarsest_stride <= 0: | |||
| return im, im_info | |||
| im_c, im_h, im_w = im.shape | |||
| pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) | |||
| pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) | |||
| padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) | |||
| padding_im[:, :im_h, :im_w] = im | |||
| return padding_im, im_info | |||
| def decode_image(im_file, im_info): | |||
| """read rgb image | |||
| Args: | |||
| im_file (str|np.ndarray): input can be image path or np.ndarray | |||
| im_info (dict): info of image | |||
| Returns: | |||
| im (np.ndarray): processed image (np.ndarray) | |||
| im_info (dict): info of processed image | |||
| """ | |||
| if isinstance(im_file, str): | |||
| with open(im_file, 'rb') as f: | |||
| im_read = f.read() | |||
| data = np.frombuffer(im_read, dtype='uint8') | |||
| im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode | |||
| im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) | |||
| else: | |||
| im = im_file | |||
| im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) | |||
| im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) | |||
| return im, im_info | |||
| def preprocess(im, preprocess_ops): | |||
| # process image by preprocess_ops | |||
| im_info = { | |||
| 'scale_factor': np.array( | |||
| [1., 1.], dtype=np.float32), | |||
| 'im_shape': None, | |||
| } | |||
| im, im_info = decode_image(im, im_info) | |||
| for operator in preprocess_ops: | |||
| im, im_info = operator(im, im_info) | |||
| return im, im_info | |||
| @@ -0,0 +1,354 @@ | |||
| import copy | |||
| import numpy as np | |||
| import cv2 | |||
| import paddle | |||
| from shapely.geometry import Polygon | |||
| import pyclipper | |||
| def build_post_process(config, global_config=None): | |||
| support_dict = ['DBPostProcess', 'CTCLabelDecode'] | |||
| config = copy.deepcopy(config) | |||
| module_name = config.pop('name') | |||
| if module_name == "None": | |||
| return | |||
| if global_config is not None: | |||
| config.update(global_config) | |||
| assert module_name in support_dict, Exception( | |||
| 'post process only support {}'.format(support_dict)) | |||
| module_class = eval(module_name)(**config) | |||
| return module_class | |||
| class DBPostProcess(object): | |||
| """ | |||
| The post process for Differentiable Binarization (DB). | |||
| """ | |||
| def __init__(self, | |||
| thresh=0.3, | |||
| box_thresh=0.7, | |||
| max_candidates=1000, | |||
| unclip_ratio=2.0, | |||
| use_dilation=False, | |||
| score_mode="fast", | |||
| box_type='quad', | |||
| **kwargs): | |||
| self.thresh = thresh | |||
| self.box_thresh = box_thresh | |||
| self.max_candidates = max_candidates | |||
| self.unclip_ratio = unclip_ratio | |||
| self.min_size = 3 | |||
| self.score_mode = score_mode | |||
| self.box_type = box_type | |||
| assert score_mode in [ | |||
| "slow", "fast" | |||
| ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) | |||
| self.dilation_kernel = None if not use_dilation else np.array( | |||
| [[1, 1], [1, 1]]) | |||
| def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): | |||
| ''' | |||
| _bitmap: single map with shape (1, H, W), | |||
| whose values are binarized as {0, 1} | |||
| ''' | |||
| bitmap = _bitmap | |||
| height, width = bitmap.shape | |||
| boxes = [] | |||
| scores = [] | |||
| contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), | |||
| cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | |||
| for contour in contours[:self.max_candidates]: | |||
| epsilon = 0.002 * cv2.arcLength(contour, True) | |||
| approx = cv2.approxPolyDP(contour, epsilon, True) | |||
| points = approx.reshape((-1, 2)) | |||
| if points.shape[0] < 4: | |||
| continue | |||
| score = self.box_score_fast(pred, points.reshape(-1, 2)) | |||
| if self.box_thresh > score: | |||
| continue | |||
| if points.shape[0] > 2: | |||
| box = self.unclip(points, self.unclip_ratio) | |||
| if len(box) > 1: | |||
| continue | |||
| else: | |||
| continue | |||
| box = box.reshape(-1, 2) | |||
| _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) | |||
| if sside < self.min_size + 2: | |||
| continue | |||
| box = np.array(box) | |||
| box[:, 0] = np.clip( | |||
| np.round(box[:, 0] / width * dest_width), 0, dest_width) | |||
| box[:, 1] = np.clip( | |||
| np.round(box[:, 1] / height * dest_height), 0, dest_height) | |||
| boxes.append(box.tolist()) | |||
| scores.append(score) | |||
| return boxes, scores | |||
| def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): | |||
| ''' | |||
| _bitmap: single map with shape (1, H, W), | |||
| whose values are binarized as {0, 1} | |||
| ''' | |||
| bitmap = _bitmap | |||
| height, width = bitmap.shape | |||
| outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, | |||
| cv2.CHAIN_APPROX_SIMPLE) | |||
| if len(outs) == 3: | |||
| img, contours, _ = outs[0], outs[1], outs[2] | |||
| elif len(outs) == 2: | |||
| contours, _ = outs[0], outs[1] | |||
| num_contours = min(len(contours), self.max_candidates) | |||
| boxes = [] | |||
| scores = [] | |||
| for index in range(num_contours): | |||
| contour = contours[index] | |||
| points, sside = self.get_mini_boxes(contour) | |||
| if sside < self.min_size: | |||
| continue | |||
| points = np.array(points) | |||
| if self.score_mode == "fast": | |||
| score = self.box_score_fast(pred, points.reshape(-1, 2)) | |||
| else: | |||
| score = self.box_score_slow(pred, contour) | |||
| if self.box_thresh > score: | |||
| continue | |||
| box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2) | |||
| box, sside = self.get_mini_boxes(box) | |||
| if sside < self.min_size + 2: | |||
| continue | |||
| box = np.array(box) | |||
| box[:, 0] = np.clip( | |||
| np.round(box[:, 0] / width * dest_width), 0, dest_width) | |||
| box[:, 1] = np.clip( | |||
| np.round(box[:, 1] / height * dest_height), 0, dest_height) | |||
| boxes.append(box.astype("int32")) | |||
| scores.append(score) | |||
| return np.array(boxes, dtype="int32"), scores | |||
| def unclip(self, box, unclip_ratio): | |||
| poly = Polygon(box) | |||
| distance = poly.area * unclip_ratio / poly.length | |||
| offset = pyclipper.PyclipperOffset() | |||
| offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | |||
| expanded = np.array(offset.Execute(distance)) | |||
| return expanded | |||
| def get_mini_boxes(self, contour): | |||
| bounding_box = cv2.minAreaRect(contour) | |||
| points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) | |||
| index_1, index_2, index_3, index_4 = 0, 1, 2, 3 | |||
| if points[1][1] > points[0][1]: | |||
| index_1 = 0 | |||
| index_4 = 1 | |||
| else: | |||
| index_1 = 1 | |||
| index_4 = 0 | |||
| if points[3][1] > points[2][1]: | |||
| index_2 = 2 | |||
| index_3 = 3 | |||
| else: | |||
| index_2 = 3 | |||
| index_3 = 2 | |||
| box = [ | |||
| points[index_1], points[index_2], points[index_3], points[index_4] | |||
| ] | |||
| return box, min(bounding_box[1]) | |||
| def box_score_fast(self, bitmap, _box): | |||
| ''' | |||
| box_score_fast: use bbox mean score as the mean score | |||
| ''' | |||
| h, w = bitmap.shape[:2] | |||
| box = _box.copy() | |||
| xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) | |||
| xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) | |||
| ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) | |||
| ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) | |||
| mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) | |||
| box[:, 0] = box[:, 0] - xmin | |||
| box[:, 1] = box[:, 1] - ymin | |||
| cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) | |||
| return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] | |||
| def box_score_slow(self, bitmap, contour): | |||
| ''' | |||
| box_score_slow: use polyon mean score as the mean score | |||
| ''' | |||
| h, w = bitmap.shape[:2] | |||
| contour = contour.copy() | |||
| contour = np.reshape(contour, (-1, 2)) | |||
| xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) | |||
| xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) | |||
| ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) | |||
| ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) | |||
| mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) | |||
| contour[:, 0] = contour[:, 0] - xmin | |||
| contour[:, 1] = contour[:, 1] - ymin | |||
| cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) | |||
| return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] | |||
| def __call__(self, outs_dict, shape_list): | |||
| pred = outs_dict['maps'] | |||
| if isinstance(pred, paddle.Tensor): | |||
| pred = pred.numpy() | |||
| pred = pred[:, 0, :, :] | |||
| segmentation = pred > self.thresh | |||
| boxes_batch = [] | |||
| for batch_index in range(pred.shape[0]): | |||
| src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] | |||
| if self.dilation_kernel is not None: | |||
| mask = cv2.dilate( | |||
| np.array(segmentation[batch_index]).astype(np.uint8), | |||
| self.dilation_kernel) | |||
| else: | |||
| mask = segmentation[batch_index] | |||
| if self.box_type == 'poly': | |||
| boxes, scores = self.polygons_from_bitmap(pred[batch_index], | |||
| mask, src_w, src_h) | |||
| elif self.box_type == 'quad': | |||
| boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, | |||
| src_w, src_h) | |||
| else: | |||
| raise ValueError( | |||
| "box_type can only be one of ['quad', 'poly']") | |||
| boxes_batch.append({'points': boxes}) | |||
| return boxes_batch | |||
| class BaseRecLabelDecode(object): | |||
| """ Convert between text-label and text-index """ | |||
| def __init__(self, character_dict_path=None, use_space_char=False): | |||
| self.beg_str = "sos" | |||
| self.end_str = "eos" | |||
| self.reverse = False | |||
| self.character_str = [] | |||
| if character_dict_path is None: | |||
| self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" | |||
| dict_character = list(self.character_str) | |||
| else: | |||
| with open(character_dict_path, "rb") as fin: | |||
| lines = fin.readlines() | |||
| for line in lines: | |||
| line = line.decode('utf-8').strip("\n").strip("\r\n") | |||
| self.character_str.append(line) | |||
| if use_space_char: | |||
| self.character_str.append(" ") | |||
| dict_character = list(self.character_str) | |||
| if 'arabic' in character_dict_path: | |||
| self.reverse = True | |||
| dict_character = self.add_special_char(dict_character) | |||
| self.dict = {} | |||
| for i, char in enumerate(dict_character): | |||
| self.dict[char] = i | |||
| self.character = dict_character | |||
| def pred_reverse(self, pred): | |||
| pred_re = [] | |||
| c_current = '' | |||
| for c in pred: | |||
| if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)): | |||
| if c_current != '': | |||
| pred_re.append(c_current) | |||
| pred_re.append(c) | |||
| c_current = '' | |||
| else: | |||
| c_current += c | |||
| if c_current != '': | |||
| pred_re.append(c_current) | |||
| return ''.join(pred_re[::-1]) | |||
| def add_special_char(self, dict_character): | |||
| return dict_character | |||
| def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | |||
| """ convert text-index into text-label. """ | |||
| result_list = [] | |||
| ignored_tokens = self.get_ignored_tokens() | |||
| batch_size = len(text_index) | |||
| for batch_idx in range(batch_size): | |||
| selection = np.ones(len(text_index[batch_idx]), dtype=bool) | |||
| if is_remove_duplicate: | |||
| selection[1:] = text_index[batch_idx][1:] != text_index[ | |||
| batch_idx][:-1] | |||
| for ignored_token in ignored_tokens: | |||
| selection &= text_index[batch_idx] != ignored_token | |||
| char_list = [ | |||
| self.character[text_id] | |||
| for text_id in text_index[batch_idx][selection] | |||
| ] | |||
| if text_prob is not None: | |||
| conf_list = text_prob[batch_idx][selection] | |||
| else: | |||
| conf_list = [1] * len(selection) | |||
| if len(conf_list) == 0: | |||
| conf_list = [0] | |||
| text = ''.join(char_list) | |||
| if self.reverse: # for arabic rec | |||
| text = self.pred_reverse(text) | |||
| result_list.append((text, np.mean(conf_list).tolist())) | |||
| return result_list | |||
| def get_ignored_tokens(self): | |||
| return [0] # for ctc blank | |||
| class CTCLabelDecode(BaseRecLabelDecode): | |||
| """ Convert between text-label and text-index """ | |||
| def __init__(self, character_dict_path=None, use_space_char=False, | |||
| **kwargs): | |||
| super(CTCLabelDecode, self).__init__(character_dict_path, | |||
| use_space_char) | |||
| def __call__(self, preds, label=None, *args, **kwargs): | |||
| if isinstance(preds, tuple) or isinstance(preds, list): | |||
| preds = preds[-1] | |||
| if isinstance(preds, paddle.Tensor): | |||
| preds = preds.numpy() | |||
| preds_idx = preds.argmax(axis=2) | |||
| preds_prob = preds.max(axis=2) | |||
| text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) | |||
| if label is None: | |||
| return text | |||
| label = self.decode(label) | |||
| return text, label | |||
| def add_special_char(self, dict_character): | |||
| dict_character = ['blank'] + dict_character | |||
| return dict_character | |||
| @@ -0,0 +1,139 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| import onnxruntime as ort | |||
| from huggingface_hub import snapshot_download | |||
| from .operators import * | |||
| from rag.settings import cron_logger | |||
| class Recognizer(object): | |||
| def __init__(self, label_list, task_name, model_dir=None): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| For Linux: | |||
| export HF_ENDPOINT=https://hf-mirror.com | |||
| For Windows: | |||
| Good luck | |||
| ^_- | |||
| """ | |||
| if not model_dir: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/ocr") | |||
| model_file_path = os.path.join(model_dir, task_name + ".onnx") | |||
| if not os.path.exists(model_file_path): | |||
| raise ValueError("not find model file path {}".format( | |||
| model_file_path)) | |||
| if ort.get_device() == "GPU": | |||
| self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) | |||
| else: | |||
| self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) | |||
| self.label_list = label_list | |||
| def create_inputs(self, imgs, im_info): | |||
| """generate input for different model type | |||
| Args: | |||
| imgs (list(numpy)): list of images (np.ndarray) | |||
| im_info (list(dict)): list of image info | |||
| Returns: | |||
| inputs (dict): input of model | |||
| """ | |||
| inputs = {} | |||
| im_shape = [] | |||
| scale_factor = [] | |||
| if len(imgs) == 1: | |||
| inputs['image'] = np.array((imgs[0],)).astype('float32') | |||
| inputs['im_shape'] = np.array( | |||
| (im_info[0]['im_shape'],)).astype('float32') | |||
| inputs['scale_factor'] = np.array( | |||
| (im_info[0]['scale_factor'],)).astype('float32') | |||
| return inputs | |||
| for e in im_info: | |||
| im_shape.append(np.array((e['im_shape'],)).astype('float32')) | |||
| scale_factor.append(np.array((e['scale_factor'],)).astype('float32')) | |||
| inputs['im_shape'] = np.concatenate(im_shape, axis=0) | |||
| inputs['scale_factor'] = np.concatenate(scale_factor, axis=0) | |||
| imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs] | |||
| max_shape_h = max([e[0] for e in imgs_shape]) | |||
| max_shape_w = max([e[1] for e in imgs_shape]) | |||
| padding_imgs = [] | |||
| for img in imgs: | |||
| im_c, im_h, im_w = img.shape[:] | |||
| padding_im = np.zeros( | |||
| (im_c, max_shape_h, max_shape_w), dtype=np.float32) | |||
| padding_im[:, :im_h, :im_w] = img | |||
| padding_imgs.append(padding_im) | |||
| inputs['image'] = np.stack(padding_imgs, axis=0) | |||
| return inputs | |||
| def preprocess(self, image_list): | |||
| preprocess_ops = [] | |||
| for op_info in [ | |||
| {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'}, | |||
| {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'}, | |||
| {'type': 'Permute'}, | |||
| {'stride': 32, 'type': 'PadStride'} | |||
| ]: | |||
| new_op_info = op_info.copy() | |||
| op_type = new_op_info.pop('type') | |||
| preprocess_ops.append(eval(op_type)(**new_op_info)) | |||
| inputs = [] | |||
| for im_path in image_list: | |||
| im, im_info = preprocess(im_path, preprocess_ops) | |||
| inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')}) | |||
| return inputs | |||
| def __call__(self, image_list, thr=0.7, batch_size=16): | |||
| res = [] | |||
| imgs = [] | |||
| for i in range(len(image_list)): | |||
| if not isinstance(image_list[i], np.ndarray): | |||
| imgs.append(np.array(image_list[i])) | |||
| else: imgs.append(image_list[i]) | |||
| batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size) | |||
| for i in range(batch_loop_cnt): | |||
| start_index = i * batch_size | |||
| end_index = min((i + 1) * batch_size, len(imgs)) | |||
| batch_image_list = imgs[start_index:end_index] | |||
| inputs = self.preprocess(batch_image_list) | |||
| for ins in inputs: | |||
| bb = [] | |||
| for b in self.ort_sess.run(None, ins)[0]: | |||
| clsid, bbox, score = int(b[0]), b[2:], b[1] | |||
| if score < thr: | |||
| continue | |||
| if clsid >= len(self.label_list): | |||
| cron_logger.warning(f"bad category id") | |||
| continue | |||
| bb.append({ | |||
| "type": self.label_list[clsid].lower(), | |||
| "bbox": [float(t) for t in bbox.tolist()], | |||
| "score": float(score) | |||
| }) | |||
| res.append(bb) | |||
| #seeit.save_results(image_list, res, self.label_list, threshold=thr) | |||
| return res | |||
| @@ -0,0 +1,83 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| import PIL | |||
| from PIL import ImageDraw | |||
| def save_results(image_list, results, labels, output_dir='output/', threshold=0.5): | |||
| if not os.path.exists(output_dir): | |||
| os.makedirs(output_dir) | |||
| for idx, im in enumerate(image_list): | |||
| im = draw_box(im, results[idx], labels, threshold=threshold) | |||
| out_path = os.path.join(output_dir, f"{idx}.jpg") | |||
| im.save(out_path, quality=95) | |||
| print("save result to: " + out_path) | |||
| def draw_box(im, result, lables, threshold=0.5): | |||
| draw_thickness = min(im.size) // 320 | |||
| draw = ImageDraw.Draw(im) | |||
| color_list = get_color_map_list(len(lables)) | |||
| clsid2color = {n.lower():color_list[i] for i,n in enumerate(lables)} | |||
| result = [r for r in result if r["score"] >= threshold] | |||
| for dt in result: | |||
| color = tuple(clsid2color[dt["type"]]) | |||
| xmin, ymin, xmax, ymax = dt["bbox"] | |||
| draw.line( | |||
| [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), | |||
| (xmin, ymin)], | |||
| width=draw_thickness, | |||
| fill=color) | |||
| # draw label | |||
| text = "{} {:.4f}".format(dt["type"], dt["score"]) | |||
| tw, th = imagedraw_textsize_c(draw, text) | |||
| draw.rectangle( | |||
| [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) | |||
| draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) | |||
| return im | |||
| def get_color_map_list(num_classes): | |||
| """ | |||
| Args: | |||
| num_classes (int): number of class | |||
| Returns: | |||
| color_map (list): RGB color list | |||
| """ | |||
| color_map = num_classes * [0, 0, 0] | |||
| for i in range(0, num_classes): | |||
| j = 0 | |||
| lab = i | |||
| while lab: | |||
| color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) | |||
| color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) | |||
| color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) | |||
| j += 1 | |||
| lab >>= 3 | |||
| color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] | |||
| return color_map | |||
| def imagedraw_textsize_c(draw, text): | |||
| if int(PIL.__version__.split('.')[0]) < 10: | |||
| tw, th = draw.textsize(text) | |||
| else: | |||
| left, top, right, bottom = draw.textbbox((0, 0), text) | |||
| tw, th = right - left, bottom - top | |||
| return tw, th | |||
| @@ -1,15 +1,24 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import random | |||
| import re | |||
| import numpy as np | |||
| from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \ | |||
| from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, \ | |||
| hierarchical_merge, make_colon_as_title, naive_merge, random_choices | |||
| from rag.nlp import huqie | |||
| from rag.parser.docx_parser import HuDocxParser | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser, DocxParser | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| @@ -21,7 +30,7 @@ class Pdf(HuParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.47, "Layout analysis finished") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| @@ -53,7 +62,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k | |||
| sections,tbls = [], [] | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| doc_parser = HuDocxParser() | |||
| doc_parser = DocxParser() | |||
| # TODO: table of contents need to be removed | |||
| sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) | |||
| remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) | |||
| @@ -1,16 +1,27 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import re | |||
| from io import BytesIO | |||
| from docx import Document | |||
| from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \ | |||
| from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \ | |||
| make_colon_as_title | |||
| from rag.nlp import huqie | |||
| from rag.parser.docx_parser import HuDocxParser | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser, DocxParser | |||
| from rag.settings import cron_logger | |||
| class Docx(HuDocxParser): | |||
| class Docx(DocxParser): | |||
| def __init__(self): | |||
| pass | |||
| @@ -35,7 +46,7 @@ class Docx(HuDocxParser): | |||
| return [l for l in lines if l] | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| @@ -47,7 +58,7 @@ class Pdf(HuParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.77, "Layout analysis finished") | |||
| cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) | |||
| self._naive_vertical_merge() | |||
| @@ -1,12 +1,12 @@ | |||
| import copy | |||
| import re | |||
| from rag.parser import tokenize | |||
| from deepdoc.parser import tokenize | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser | |||
| from rag.utils import num_tokens_from_string | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| @@ -18,7 +18,7 @@ class Pdf(HuParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.5, "Layout analysis finished.") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| @@ -1,13 +1,25 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import re | |||
| from rag.app import laws | |||
| from rag.parser import is_english, tokenize, naive_merge | |||
| from deepdoc.parser import is_english, tokenize, naive_merge | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser | |||
| from rag.settings import cron_logger | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| @@ -19,7 +31,7 @@ class Pdf(HuParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.77, "Layout analysis finished") | |||
| cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) | |||
| self._naive_vertical_merge() | |||
| @@ -1,16 +1,28 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import re | |||
| from collections import Counter | |||
| from api.db import ParserType | |||
| from rag.parser import tokenize | |||
| from deepdoc.parser import tokenize | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser | |||
| import numpy as np | |||
| from rag.utils import num_tokens_from_string | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __init__(self): | |||
| self.model_speciess = ParserType.PAPER.value | |||
| super().__init__() | |||
| @@ -26,7 +38,7 @@ class Pdf(HuParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| self._layouts_paddle(zoomin) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.47, "Layout analysis finished") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| @@ -1,11 +1,22 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import re | |||
| from io import BytesIO | |||
| from pptx import Presentation | |||
| from rag.parser import tokenize, is_english | |||
| from deepdoc.parser import tokenize, is_english | |||
| from rag.nlp import huqie | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import PdfParser | |||
| class Ppt(object): | |||
| @@ -58,7 +69,7 @@ class Ppt(object): | |||
| return [(txts[i], imgs[i]) for i in range(len(txts))] | |||
| class Pdf(HuParser): | |||
| class Pdf(PdfParser): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -74,7 +85,7 @@ class Pdf(HuParser): | |||
| assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) | |||
| res = [] | |||
| #################### More precisely ################### | |||
| # self._layouts_paddle(zoomin) | |||
| # self._layouts_rec(zoomin) | |||
| # self._text_merge() | |||
| # pages = {} | |||
| # for b in self.boxes: | |||
| @@ -1,13 +1,25 @@ | |||
| import random | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import re | |||
| from io import BytesIO | |||
| from nltk import word_tokenize | |||
| from openpyxl import load_workbook | |||
| from rag.parser import is_english, random_choices | |||
| from deepdoc.parser import is_english, random_choices | |||
| from rag.nlp import huqie, stemmer | |||
| from deepdoc.parser import ExcelParser | |||
| class Excel(object): | |||
| class Excel(ExcelParser): | |||
| def __call__(self, fnm, binary=None, callback=None): | |||
| if not binary: | |||
| wb = load_workbook(fnm) | |||
| @@ -1,59 +1,82 @@ | |||
| import copy | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import base64 | |||
| import datetime | |||
| import json | |||
| import os | |||
| import re | |||
| import pandas as pd | |||
| import requests | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.settings import stat_logger | |||
| from rag.nlp import huqie | |||
| from deepdoc.parser.resume import refactor | |||
| from deepdoc.parser.resume import step_one, step_two | |||
| from rag.settings import cron_logger | |||
| from rag.utils import rmSpace | |||
| forbidden_select_fields4resume = [ | |||
| "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" | |||
| ] | |||
| def remote_call(filename, binary): | |||
| q = { | |||
| "header": { | |||
| "uid": 1, | |||
| "user": "kevinhu", | |||
| "log_id": filename | |||
| }, | |||
| "request": { | |||
| "p": { | |||
| "request_id": "1", | |||
| "encrypt_type": "base64", | |||
| "filename": filename, | |||
| "langtype": '', | |||
| "fileori": base64.b64encode(binary.stream.read()).decode('utf-8') | |||
| }, | |||
| "c": "resume_parse_module", | |||
| "m": "resume_parse" | |||
| } | |||
| } | |||
| for _ in range(3): | |||
| try: | |||
| resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q)) | |||
| resume = resume.json()["response"]["results"] | |||
| resume = refactor(resume) | |||
| for k in ["education", "work", "project", "training", "skill", "certificate", "language"]: | |||
| if not resume.get(k) and k in resume: del resume[k] | |||
| resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", | |||
| "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) | |||
| resume = step_two.parse(resume) | |||
| return resume | |||
| except Exception as e: | |||
| cron_logger.error("Resume parser error: "+str(e)) | |||
| return {} | |||
| def chunk(filename, binary=None, callback=None, **kwargs): | |||
| """ | |||
| The supported file formats are pdf, docx and txt. | |||
| To maximize the effectiveness, parse the resume correctly, | |||
| please visit https://github.com/infiniflow/ragflow, and sign in the our demo web-site | |||
| to get token. It's FREE! | |||
| Set INFINIFLOW_SERVER and INFINIFLOW_TOKEN in '.env' file or | |||
| using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN in docker container. | |||
| To maximize the effectiveness, parse the resume correctly, please concat us: https://github.com/infiniflow/ragflow | |||
| """ | |||
| if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): | |||
| raise NotImplementedError("file type not supported yet(pdf supported)") | |||
| url = os.environ.get("INFINIFLOW_SERVER") | |||
| token = os.environ.get("INFINIFLOW_TOKEN") | |||
| if not url or not token: | |||
| stat_logger.warning( | |||
| "INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") | |||
| return [] | |||
| if not binary: | |||
| with open(filename, "rb") as f: | |||
| binary = f.read() | |||
| def remote_call(): | |||
| nonlocal filename, binary | |||
| for _ in range(3): | |||
| try: | |||
| res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)], | |||
| headers={"Authorization": token}, timeout=180) | |||
| res = res.json() | |||
| if res["retcode"] != 0: | |||
| raise RuntimeError(res["retmsg"]) | |||
| return res["data"] | |||
| except RuntimeError as e: | |||
| raise e | |||
| except Exception as e: | |||
| cron_logger.error("resume parsing:" + str(e)) | |||
| callback(0.2, "Resume parsing is going on...") | |||
| resume = remote_call() | |||
| resume = remote_call(filename, binary) | |||
| if len(resume.keys()) < 7: | |||
| callback(-1, "Resume is not successfully parsed.") | |||
| return [] | |||
| @@ -1,3 +1,15 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import copy | |||
| import re | |||
| from io import BytesIO | |||
| @@ -8,11 +20,12 @@ from openpyxl import load_workbook | |||
| from dateutil.parser import parse as datetime_parse | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from rag.parser import is_english, tokenize | |||
| from rag.nlp import huqie, stemmer | |||
| from deepdoc.parser import is_english, tokenize | |||
| from rag.nlp import huqie | |||
| from deepdoc.parser import ExcelParser | |||
| class Excel(object): | |||
| class Excel(ExcelParser): | |||
| def __call__(self, fnm, binary=None, callback=None): | |||
| if not binary: | |||
| wb = load_workbook(fnm) | |||
| @@ -1,3 +1,15 @@ | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import re | |||
| import os | |||
| import copy | |||
| @@ -443,13 +455,13 @@ if __name__ == "__main__": | |||
| import sys | |||
| sys.path.append(os.path.dirname(__file__) + "/../") | |||
| if sys.argv[1].split(".")[-1].lower() == "pdf": | |||
| from parser import PdfParser | |||
| from deepdoc.parser import PdfParser | |||
| ckr = PdfChunker(PdfParser()) | |||
| if sys.argv[1].split(".")[-1].lower().find("doc") >= 0: | |||
| from parser import DocxParser | |||
| from deepdoc.parser import DocxParser | |||
| ckr = DocxChunker(DocxParser()) | |||
| if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0: | |||
| from parser import ExcelParser | |||
| from deepdoc.parser import ExcelParser | |||
| ckr = ExcelChunker(ExcelParser()) | |||
| # ckr.html(sys.argv[1]) | |||
| @@ -21,7 +21,7 @@ from datetime import datetime | |||
| from api.db.db_models import Task | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.db.services.task_service import TaskService | |||
| from rag.parser.pdf_parser import HuParser | |||
| from deepdoc.parser import HuParser | |||
| from rag.settings import cron_logger | |||
| from rag.utils import MINIO | |||
| from rag.utils import findMaxTm | |||