| return {"answer": prompt_config["empty_response"], "retrieval": kbinfos} | return {"answer": prompt_config["empty_response"], "retrieval": kbinfos} | ||||
| kwargs["knowledge"] = "\n".join(knowledges) | kwargs["knowledge"] = "\n".join(knowledges) | ||||
| gen_conf = dialog.llm_setting[dialog.llm_setting_type] | |||||
| gen_conf = dialog.llm_setting | |||||
| msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] | msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] | ||||
| used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) | used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) | ||||
| if "max_tokens" in gen_conf: | if "max_tokens" in gen_conf: |
| name = req.get("name", "New Dialog") | name = req.get("name", "New Dialog") | ||||
| description = req.get("description", "A helpful Dialog") | description = req.get("description", "A helpful Dialog") | ||||
| language = req.get("language", "Chinese") | language = req.get("language", "Chinese") | ||||
| llm_setting_type = req.get("llm_setting_type", "Precise") | |||||
| top_n = req.get("top_n", 6) | |||||
| similarity_threshold = req.get("similarity_threshold", 0.1) | |||||
| vector_similarity_weight = req.get("vector_similarity_weight", 0.3) | |||||
| llm_setting = req.get("llm_setting", { | llm_setting = req.get("llm_setting", { | ||||
| "Creative": { | |||||
| "temperature": 0.9, | |||||
| "top_p": 0.9, | |||||
| "frequency_penalty": 0.2, | |||||
| "presence_penalty": 0.4, | |||||
| "max_tokens": 512 | |||||
| }, | |||||
| "Precise": { | |||||
| "temperature": 0.1, | |||||
| "top_p": 0.3, | |||||
| "frequency_penalty": 0.7, | |||||
| "presence_penalty": 0.4, | |||||
| "max_tokens": 215 | |||||
| }, | |||||
| "Evenly": { | |||||
| "temperature": 0.5, | |||||
| "top_p": 0.5, | |||||
| "frequency_penalty": 0.7, | |||||
| "presence_penalty": 0.4, | |||||
| "max_tokens": 215 | |||||
| }, | |||||
| "Custom": { | |||||
| "temperature": 0.2, | |||||
| "top_p": 0.3, | |||||
| "frequency_penalty": 0.6, | |||||
| "presence_penalty": 0.3, | |||||
| "max_tokens": 215 | |||||
| }, | |||||
| "temperature": 0.1, | |||||
| "top_p": 0.3, | |||||
| "frequency_penalty": 0.7, | |||||
| "presence_penalty": 0.4, | |||||
| "max_tokens": 215 | |||||
| }) | }) | ||||
| prompt_config = req.get("prompt_config", { | |||||
| default_prompt = { | |||||
| "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。 | "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。 | ||||
| 以下是知识库: | 以下是知识库: | ||||
| {knowledge} | {knowledge} | ||||
| {"key": "knowledge", "optional": False} | {"key": "knowledge", "optional": False} | ||||
| ], | ], | ||||
| "empty_response": "Sorry! 知识库中未找到相关内容!" | "empty_response": "Sorry! 知识库中未找到相关内容!" | ||||
| }) | |||||
| } | |||||
| prompt_config = req.get("prompt_config", default_prompt) | |||||
| if len(prompt_config["parameters"]) < 1: | |||||
| return get_data_error_result(retmsg="'knowledge' should be in parameters") | |||||
| if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"] | |||||
| # if len(prompt_config["parameters"]) < 1: | |||||
| # prompt_config["parameters"] = default_prompt["parameters"] | |||||
| # for p in prompt_config["parameters"]: | |||||
| # if p["key"] == "knowledge":break | |||||
| # else: prompt_config["parameters"].append(default_prompt["parameters"][0]) | |||||
| for p in prompt_config["parameters"]: | for p in prompt_config["parameters"]: | ||||
| if prompt_config["system"].find("{%s}"%p["key"]) < 0: | |||||
| if p["optional"]: continue | |||||
| if prompt_config["system"].find("{%s}" % p["key"]) < 0: | |||||
| return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"])) | return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"])) | ||||
| try: | try: | ||||
| e, tenant = TenantService.get_by_id(current_user.id) | e, tenant = TenantService.get_by_id(current_user.id) | ||||
| if not e:return get_data_error_result(retmsg="Tenant not found!") | |||||
| if not e: return get_data_error_result(retmsg="Tenant not found!") | |||||
| llm_id = req.get("llm_id", tenant.llm_id) | llm_id = req.get("llm_id", tenant.llm_id) | ||||
| if not dialog_id: | if not dialog_id: | ||||
| if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!") | |||||
| dia = { | dia = { | ||||
| "id": get_uuid(), | "id": get_uuid(), | ||||
| "tenant_id": current_user.id, | "tenant_id": current_user.id, | ||||
| "name": name, | "name": name, | ||||
| "kb_ids": req["kb_ids"], | |||||
| "description": description, | "description": description, | ||||
| "language": language, | "language": language, | ||||
| "llm_id": llm_id, | "llm_id": llm_id, | ||||
| "llm_setting_type": llm_setting_type, | |||||
| "llm_setting": llm_setting, | "llm_setting": llm_setting, | ||||
| "prompt_config": prompt_config | |||||
| "prompt_config": prompt_config, | |||||
| "top_n": top_n, | |||||
| "similarity_threshold": similarity_threshold, | |||||
| "vector_similarity_weight": vector_similarity_weight | |||||
| } | } | ||||
| if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") | if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") | ||||
| e, dia = DialogService.get_by_id(dia["id"]) | e, dia = DialogService.get_by_id(dia["id"]) | ||||
| def get(): | def get(): | ||||
| dialog_id = request.args["dialog_id"] | dialog_id = request.args["dialog_id"] | ||||
| try: | try: | ||||
| e,dia = DialogService.get_by_id(dialog_id) | |||||
| e, dia = DialogService.get_by_id(dialog_id) | |||||
| if not e: return get_data_error_result(retmsg="Dialog not found!") | if not e: return get_data_error_result(retmsg="Dialog not found!") | ||||
| dia = dia.to_dict() | dia = dia.to_dict() | ||||
| dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) | dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
def get_kb_names(kb_ids):
    """Resolve knowledge-base ids to their names.

    Ids that cannot be loaded, or whose knowledge base is not in the VALID
    status, are left out. Returns two parallel lists: the ids that were kept
    and the corresponding names.
    """
    kept_ids = []
    kept_names = []
    for kb_id in kb_ids:
        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if found and kb.status == StatusEnum.VALID.value:
            kept_ids.append(kb_id)
            kept_names.append(kb.name)
    return kept_ids, kept_names
| @manager.route('/list', methods=['GET']) | @manager.route('/list', methods=['GET']) | ||||
| @login_required | @login_required | ||||
| def list(): | def list(): | ||||
| try: | try: | ||||
| diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value) | |||||
| diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time) | |||||
| diags = [d.to_dict() for d in diags] | diags = [d.to_dict() for d in diags] | ||||
| for d in diags: | for d in diags: | ||||
| d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) | d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) | ||||
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("dialog_ids")
def rm():
    """Soft-delete dialogs.

    Every id listed under ``dialog_ids`` in the JSON request body has its
    status set to INVALID in a single bulk update. Any exception is turned
    into a server error response.
    """
    payload = request.json
    try:
        updates = []
        for dialog_id in payload["dialog_ids"]:
            updates.append({"id": dialog_id, "status": StatusEnum.INVALID.value})
        DialogService.update_many_by_id(updates)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
| icon = CharField(max_length=16, null=False, help_text="dialog icon") | icon = CharField(max_length=16, null=False, help_text="dialog icon") | ||||
| language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") | language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") | ||||
| llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | ||||
| llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom", | |||||
| default="Creative") | |||||
| llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | ||||
| "presence_penalty": 0.4, "max_tokens": 215}) | "presence_penalty": 0.4, "max_tokens": 215}) | ||||
| prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") | prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") |
| import copy | |||||
| import random | import random | ||||
| from .pdf_parser import HuParser as PdfParser | from .pdf_parser import HuParser as PdfParser | ||||
| from nltk import word_tokenize | from nltk import word_tokenize | ||||
| from rag.nlp import stemmer, huqie | from rag.nlp import stemmer, huqie | ||||
| from ..utils import num_tokens_from_string | |||||
| from rag.utils import num_tokens_from_string | |||||
| BULLET_PATTERN = [[ | BULLET_PATTERN = [[ | ||||
| r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", | r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", |
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | import os | ||||
| import random | import random | ||||
| from functools import partial | |||||
| import fitz | import fitz | ||||
| import requests | import requests | ||||
| import numpy as np | import numpy as np | ||||
| from api.db import ParserType | from api.db import ParserType | ||||
| from deepdoc.visual import OCR, Recognizer | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from collections import Counter | from collections import Counter | ||||
| from copy import deepcopy | from copy import deepcopy | ||||
| class HuParser: | class HuParser: | ||||
| def __init__(self): | def __init__(self): | ||||
| from paddleocr import PaddleOCR | |||||
| logging.getLogger("ppocr").setLevel(logging.ERROR) | |||||
| self.ocr = PaddleOCR(use_angle_cls=False, lang="ch") | |||||
| self.ocr = OCR() | |||||
| if not hasattr(self, "model_speciess"): | if not hasattr(self, "model_speciess"): | ||||
| self.model_speciess = ParserType.GENERAL.value | self.model_speciess = ParserType.GENERAL.value | ||||
| self.layouter = partial(self.__remote_call, self.model_speciess) | |||||
| self.tbl_det = partial(self.__remote_call, "table_component") | |||||
| self.layout_labels = [ | |||||
| "_background_", | |||||
| "Text", | |||||
| "Title", | |||||
| "Figure", | |||||
| "Figure caption", | |||||
| "Table", | |||||
| "Table caption", | |||||
| "Header", | |||||
| "Footer", | |||||
| "Reference", | |||||
| "Equation", | |||||
| ] | |||||
| self.tsr_labels = [ | |||||
| "table", | |||||
| "table column", | |||||
| "table row", | |||||
| "table column header", | |||||
| "table projected row header", | |||||
| "table spanning cell", | |||||
| ] | |||||
| self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/") | |||||
| self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/") | |||||
| self.updown_cnt_mdl = xgb.Booster() | self.updown_cnt_mdl = xgb.Booster() | ||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| token = os.environ.get("INFINIFLOW_TOKEN") | token = os.environ.get("INFINIFLOW_TOKEN") | ||||
| if not url or not token: | if not url or not token: | ||||
| logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") | logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") | ||||
| return [] | |||||
| return [[] for _ in range(len(images))] | |||||
| def convert_image_to_bytes(PILimage): | def convert_image_to_bytes(PILimage): | ||||
| image = BytesIO() | image = BytesIO() | ||||
| return layouts | return layouts | ||||
| def __table_paddle(self, images): | |||||
| def __table_tsr(self, images): | |||||
| tbls = self.tbl_det(images, thr=0.5) | tbls = self.tbl_det(images, thr=0.5) | ||||
| res = [] | res = [] | ||||
| # align left&right for rows, align top&bottom for columns | # align left&right for rows, align top&bottom for columns | ||||
| assert len(self.page_images) == len(tbcnt) - 1 | assert len(self.page_images) == len(tbcnt) - 1 | ||||
| if not imgs: | if not imgs: | ||||
| return | return | ||||
| recos = self.__table_paddle(imgs) | |||||
| recos = self.__table_tsr(imgs) | |||||
| tbcnt = np.cumsum(tbcnt) | tbcnt = np.cumsum(tbcnt) | ||||
| for i in range(len(tbcnt) - 1): # for page | for i in range(len(tbcnt) - 1): # for page | ||||
| pg = [] | pg = [] | ||||
| b["H_right"] = spans[ii]["x1"] | b["H_right"] = spans[ii]["x1"] | ||||
| b["SP"] = ii | b["SP"] = ii | ||||
| def __ocr_paddle(self, pagenum, img, chars, ZM=3): | |||||
| bxs = self.ocr.ocr(np.array(img), cls=True)[0] | |||||
| def __ocr(self, pagenum, img, chars, ZM=3): | |||||
| bxs = self.ocr(np.array(img)) | |||||
| if not bxs: | if not bxs: | ||||
| self.boxes.append([]) | self.boxes.append([]) | ||||
| return | return | ||||
| self.boxes.append(bxs) | self.boxes.append(bxs) | ||||
| def _layouts_paddle(self, ZM): | |||||
| def _layouts_rec(self, ZM): | |||||
| assert len(self.page_images) == len(self.boxes) | assert len(self.page_images) == len(self.boxes) | ||||
| # Tag layout type | # Tag layout type | ||||
| boxes = [] | boxes = [] | ||||
| layouts = self.layouter(self.page_images) | layouts = self.layouter(self.page_images) | ||||
| #save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7) | |||||
| assert len(self.page_images) == len(layouts) | assert len(self.page_images) == len(layouts) | ||||
| for pn, lts in enumerate(layouts): | for pn, lts in enumerate(layouts): | ||||
| bxs = self.boxes[pn] | bxs = self.boxes[pn] | ||||
| # else: | # else: | ||||
| # self.page_cum_height.append( | # self.page_cum_height.append( | ||||
| # np.max([c["bottom"] for c in chars])) | # np.max([c["bottom"] for c in chars])) | ||||
| self.__ocr_paddle(i + 1, img, chars, zoomin) | |||||
| self.__ocr(i + 1, img, chars, zoomin) | |||||
| if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: | if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: | ||||
| bxes = [b for bxs in self.boxes for b in bxs] | bxes = [b for bxs in self.boxes for b in bxs] | ||||
| def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): | def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): | ||||
| self.__images__(fnm, zoomin) | self.__images__(fnm, zoomin) | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| self._table_transformer_job(zoomin) | self._table_transformer_job(zoomin) | ||||
| self._text_merge() | self._text_merge() | ||||
| self._concat_downward() | self._concat_downward() |
| from .ocr import OCR | |||||
| from .recognizer import Recognizer |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | |||||
| import time | |||||
| import os | |||||
| from huggingface_hub import snapshot_download | |||||
| from .operators import * | |||||
| import numpy as np | |||||
| import onnxruntime as ort | |||||
| from api.utils.file_utils import get_project_base_directory | |||||
| from .postprocess import build_post_process | |||||
| from rag.settings import cron_logger | |||||
def transform(data, ops=None):
    """Feed ``data`` through each operator in ``ops``, in order.

    Each operator receives the previous operator's output. As soon as one of
    them returns ``None`` the pipeline stops and ``None`` is returned. With no
    operators the input comes back unchanged.
    """
    pipeline = [] if ops is None else ops
    current = data
    for step in pipeline:
        current = step(current)
        if current is None:
            return None
    return current
def create_operators(op_param_list, global_config=None):
    """Instantiate operators from a list of single-key config dicts.

    Args:
        op_param_list (list): each item is a dict with exactly one key, the
            operator name, mapped to a dict of keyword arguments (or ``None``
            for no arguments).
        global_config (dict | None): extra keyword arguments merged into every
            operator's arguments; on key clashes these values win.

    Returns:
        list: the constructed operators, in the order they were listed.

    Raises:
        AssertionError: if ``op_param_list`` is not a list or an item is not a
            single-key dict.
    """
    assert isinstance(
        op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        raw_param = operator[op_name]
        # Work on a copy: merging global_config must not leak into the
        # caller's config dict (which may be reused for another pipeline).
        param = {} if raw_param is None else dict(raw_param)
        if global_config is not None:
            param.update(global_config)
        # NOTE(review): eval() resolves the operator name in this module's
        # namespace. That is only safe while op_param_list comes from trusted,
        # in-repo configuration; never pass externally supplied names here.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
def load_model(model_dir, nm):
    """Open the ONNX model ``<model_dir>/<nm>.onnx``.

    Returns a pair of the inference session and the session's first input
    descriptor. Raises ``ValueError`` when the model file does not exist.
    """
    onnx_path = os.path.join(model_dir, nm + ".onnx")
    if os.path.exists(onnx_path):
        session = ort.InferenceSession(onnx_path)
        first_input = session.get_inputs()[0]
        return session, first_input
    raise ValueError("not find model file path {}".format(onnx_path))
class TextRecognizer(object):
    """Recognise text in cropped line images with an ONNX model.

    The model file ``rec.onnx`` is loaded from ``model_dir``; raw predictions
    are decoded by the ``CTCLabelDecode`` post-processor. Calling the instance
    with a list of images returns one result per image, in input order.
    """

    def __init__(self, model_dir):
        """Load ``rec.onnx`` from ``model_dir`` and build the CTC decoder."""
        # Channels, height, width expected for a normalised input image.
        self.rec_image_shape = [3, 48, 320]
        self.rec_batch_num = 16
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
            "use_space_char": True
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor = load_model(model_dir, 'rec')

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the model height, normalise it and right-pad it.

        The target width is ``imgH * max_wh_ratio`` unless the model input
        declares a fixed positive width, which then takes precedence. The
        image keeps its aspect ratio (capped at the target width), is scaled
        as ``(x / 255 - 0.5) / 0.5`` and placed at the left of a zero-filled
        ``(C, H, W)`` float32 array.
        """
        imgC, imgH, imgW = self.rec_image_shape

        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))
        w = self.input_tensor.shape[3:][0]
        if isinstance(w, str):
            # Symbolic (dynamic) width: keep the ratio-derived width.
            pass
        elif w is not None and w > 0:
            imgW = w
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))

        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_vl(self, img, image_shape):
        """Reverse the channel order, resize to ``image_shape`` and scale by 1/255."""
        imgC, imgH, imgW = image_shape
        img = img[:, :, ::-1]  # bgr2rgb
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        return resized_image

    def resize_norm_img_srn(self, img, image_shape):
        """Resize to a width bucket (1x, 2x, 3x the height, or full width),
        convert to grayscale and left-align on a zero canvas.

        Returns a float32 array of shape ``(1, imgH, imgW)``.
        """
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build the auxiliary position and attention-bias arrays for SRN.

        Returns ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
        gsrm_slf_attn_bias2]``; the two biases are the strict upper and strict
        lower triangles scaled by ``-1e9`` and tiled over ``num_heads``.
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Return the batched SRN image plus its auxiliary inputs as a tuple."""
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        """Resize/normalise for SAR and pad with ``-1.0`` up to ``imgW_max``.

        ``image_shape`` is ``(C, H, W_min, W_max)``. Returns
        ``(padded_image, resized_shape, padded_shape, valid_ratio)``.
        """
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def resize_norm_img_spin(self, img):
        """Grayscale, resize to 100x32 and normalise as ``(x - 127.5) / 127.5``."""
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # return padding_im
        # NOTE(review): cv2.INTER_CUBIC is passed positionally in the third
        # slot, which cv2.resize documents as `dst`, not `interpolation` --
        # confirm whether cubic interpolation was intended.
        img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
        img = np.array(img, np.float32)
        img = np.expand_dims(img, -1)
        img = img.transpose((2, 0, 1))
        mean = [127.5]
        std = [127.5]
        mean = np.array(mean, dtype=np.float32)
        std = np.array(std, dtype=np.float32)
        mean = np.float32(mean.reshape(1, -1))
        stdinv = 1 / np.float32(std.reshape(1, -1))
        img -= mean
        img *= stdinv
        return img

    def resize_norm_img_svtr(self, img, image_shape):
        """Resize to ``image_shape`` and normalise as ``(x / 255 - 0.5) / 0.5``."""
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_abinet(self, img, image_shape):
        """Resize to ``image_shape`` and apply per-channel mean/std normalisation."""
        imgC, imgH, imgW = image_shape

        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image / 255.

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized_image = (
            resized_image - mean[None, None, ...]) / std[None, None, ...]
        resized_image = resized_image.transpose((2, 0, 1))
        resized_image = resized_image.astype('float32')

        return resized_image

    def norm_img_can(self, img, image_shape):
        """Grayscale ``img``; for single-channel configs pad with 255 up to the
        configured size and scale to ``[0, 1]`` with a leading channel axis.
        """
        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predict gray scale image

        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded

        img = np.expand_dims(img, 0) / 255.0  # h,w,c -> c,h,w
        img = img.astype('float32')

        return img

    def __call__(self, img_list):
        """Recognise every image in ``img_list``.

        Images are processed in batches of ``rec_batch_num`` after sorting by
        aspect ratio, so images of similar width share a batch.

        Returns:
            tuple: ``(rec_res, elapsed_seconds)`` where ``rec_res[i]`` is the
            decoder output for ``img_list[i]``; slots the decoder produced no
            result for keep the placeholder ``['', 0.0]``.
        """
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        # One independent placeholder per image: `[x] * n` would make every
        # slot the same list object, so mutating one would change them all.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        st = time.time()

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            _, imgH, imgW = self.rec_image_shape[:3]
            # Start from the configured ratio and widen it to the widest
            # image in this batch so nothing in the batch gets squeezed.
            max_wh_ratio = imgW / imgH
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            input_dict = {}
            input_dict[self.input_tensor.name] = norm_img_batch
            outputs = self.predictor.run(None, input_dict)
            preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            # Undo the aspect-ratio sort: write each result back to the
            # position of the image it belongs to.
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]

        return rec_res, time.time() - st
class TextDetector(object):
    """Detect text regions in an image with an ONNX model.

    The model file ``det.onnx`` is loaded from ``model_dir``; its output map
    is turned into quadrilateral boxes by the ``DBPostProcess``
    post-processor. Calling the instance with an image returns the boxes and
    the elapsed time.
    """

    def __init__(self, model_dir):
        """Load ``det.onnx`` from ``model_dir`` and build the pre/post pipeline."""
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 960,
                'limit_type': "max",
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.6, "max_candidates": 1000,
                              "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}

        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor = load_model(model_dir, 'det')

        img_h, img_w = self.input_tensor.shape[2:]
        if isinstance(img_h, str) or isinstance(img_w, str):
            # Symbolic (dynamic) input size: keep the side-length limit.
            pass
        elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
            # The model declares a fixed input size: resize to exactly that.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'image_shape': [img_h, img_w]
                }
            }
        self.preprocess_op = create_operators(pre_process_list)

    def order_points_clockwise(self, pts):
        """Return the four points of ``pts`` as a float32 ``(4, 2)`` array.

        Slot 0 gets the point with the smallest x+y, slot 2 the largest x+y;
        of the remaining two, slot 1 gets the smaller y-x and slot 3 the
        larger y-x.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into the image bounds, in place, and return ``points``."""
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order and clip each box, dropping boxes whose width or height is <= 3.

        Returns the surviving boxes as a single array.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip each box to the image bounds without reordering or filtering."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Detect text boxes in ``img``.

        Returns:
            tuple: ``(boxes, elapsed_seconds)``; ``(None, 0)`` when
            preprocessing yields nothing for this image.
        """
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()
        data = transform(data, self.preprocess_op)
        # transform() returns None as soon as any operator rejects the image;
        # check before unpacking, otherwise the unpack raises TypeError and
        # the "no result" return below is never reached.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        input_dict = {}
        input_dict[self.input_tensor.name] = img
        outputs = self.predictor.run(None, input_dict)

        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
        dt_boxes = post_result[0]['points']
        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        return dt_boxes, time.time() - st
class OCR(object):
    """Text detection followed by text recognition on a single image."""

    def __init__(self, model_dir=None):
        """
        Build the detector and recognizer from ``model_dir``.

        When ``model_dir`` is empty the models are fetched with
        ``snapshot_download(repo_id="InfiniFlow/ocr")``.

        If you have trouble downloading HuggingFace models, this might help.
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        self.text_detector = TextDetector(model_dir)
        self.text_recognizer = TextRecognizer(model_dir)
        # Recognition results scoring below this value are dropped.
        self.drop_score = 0.5
        self.crop_image_res_index = 0

    def get_rotate_crop_image(self, img, points):
        """Warp the quadrilateral ``points`` out of ``img`` into an upright crop.

        The crop width/height are the longer of each pair of opposite
        edges. A crop whose height is at least 1.5x its width is rotated
        with ``np.rot90`` before being returned.
        """
        assert len(points) == 4, "shape of points must be 4*2"

        top_edge = np.linalg.norm(points[0] - points[1])
        bottom_edge = np.linalg.norm(points[2] - points[3])
        left_edge = np.linalg.norm(points[0] - points[3])
        right_edge = np.linalg.norm(points[1] - points[2])
        crop_w = int(max(top_edge, bottom_edge))
        crop_h = int(max(left_edge, right_edge))

        target = np.float32([[0, 0], [crop_w, 0],
                             [crop_w, crop_h],
                             [0, crop_h]])
        matrix = cv2.getPerspectiveTransform(points, target)
        warped = cv2.warpPerspective(
            img,
            matrix, (crop_w, crop_h),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)

        warped_h, warped_w = warped.shape[0:2]
        if warped_h * 1.0 / warped_w >= 1.5:
            warped = np.rot90(warped)
        return warped

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes from top to bottom, then left to right.

        Boxes are first sorted by the (y, x) of their first point; then
        adjacent boxes whose first-point y differs by less than 10 are
        swapped when the later one lies further left.

        args:
            dt_boxes(array): detected text boxes, each with shape [4, 2]
        return:
            list of the same boxes in reading order
        """
        count = dt_boxes.shape[0]
        ordered = sorted(dt_boxes, key=lambda quad: (quad[0][1], quad[0][0]))

        for end in range(count - 1):
            for j in range(end, -1, -1):
                upper, lower = ordered[j], ordered[j + 1]
                same_row = abs(lower[0][1] - upper[0][1]) < 10
                if not (same_row and lower[0][0] < upper[0][0]):
                    break
                ordered[j], ordered[j + 1] = lower, upper
        return ordered

    def __call__(self, img, cls=True):
        """Detect and recognize text in ``img``.

        Returns:
            ``(None, None, time_dict)`` when ``img`` is ``None`` or the
            detector returns no boxes; otherwise a list of
            ``(box_as_list, rec_result)`` pairs whose score is at least
            ``self.drop_score``.
        """
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
        if img is None:
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            time_dict['all'] = time.time() - start
            return None, None, time_dict

        cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
            len(dt_boxes), elapse))

        dt_boxes = self.sorted_boxes(dt_boxes)
        # Deep-copy each box so cropping cannot disturb the sorted boxes.
        img_crop_list = [
            self.get_rotate_crop_image(ori_im, copy.deepcopy(quad))
            for quad in dt_boxes
        ]

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        cron_logger.debug("rec_res num : {}, elapsed : {}".format(
            len(rec_res), elapse))

        results = []
        for quad, rec_result in zip(dt_boxes, rec_res):
            _text, score = rec_result
            if score < self.drop_score:
                continue
            results.append((quad.tolist(), rec_result))

        time_dict['all'] = time.time() - start
        return results
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import sys | |||||
| import six | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import math | |||||
| from PIL import Image | |||||
class DecodeImage(object):
    """Decode an encoded image byte string stored in ``data['image']``.

    Args:
        img_mode: ``'RGB'`` reverses the channel order of the decoded
            image; ``'GRAY'`` applies ``cv2.COLOR_GRAY2BGR``; any other
            value leaves the decoded image untouched.
        channel_first: when true, transpose the result from HWC to CHW.
        ignore_orientation: when true, decode with
            ``cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 img_mode='RGB',
                 channel_first=False,
                 ignore_orientation=False,
                 **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first
        self.ignore_orientation = ignore_orientation

    def __call__(self, data):
        """Decode ``data['image']`` in place.

        Returns:
            ``data`` with ``data['image']`` replaced by the decoded
            array, or ``None`` when OpenCV cannot decode the buffer.

        Raises:
            AssertionError: if the input is not a non-empty byte string,
                or (RGB mode) the decoded image does not have 3 channels.
        """
        img = data['image']
        if six.PY2:
            assert isinstance(img, str) and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert isinstance(img, bytes) and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')
        if self.ignore_orientation:
            img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
                               cv2.IMREAD_COLOR)
        else:
            img = cv2.imdecode(img, 1)
        if img is None:
            return None

        if self.img_mode == 'GRAY':
            # NOTE(review): both decode calls above request a color image,
            # so `img` presumably already has 3 channels here and
            # GRAY2BGR may reject it - confirm against OpenCV.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            # `img.shape` is itself a tuple, so it must be wrapped in a
            # 1-tuple; otherwise building the message raises TypeError and
            # masks the AssertionError.
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                img.shape, )
            img = img[:, :, ::-1]

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data
class StandardizeImage(object):
    """Scale an image and standardize it with a mean and std.

    Args:
        mean (list): value subtracted from the image
        std (list): value the image is divided by
        is_scale (bool): whether to multiply the image by 1/255 first
        norm_type (str): 'mean_std' applies mean/std; any other value
            skips that step
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean, self.std = mean, std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): float32 image; the arithmetic is done in
                place, so a float32 input is itself modified
            im_info (dict): the same ``im_info`` object
        """
        # copy=False keeps an already-float32 input as the working buffer.
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            im *= 1.0 / 255.0

        if self.norm_type == 'mean_std':
            # Lift both vectors to (1, 1, C) so they broadcast over H and W.
            offset, divisor = (np.array(vec)[None, None, :]
                               for vec in (self.mean, self.std))
            im -= offset
            im /= divisor
        return im, im_info
class NormalizeImage(object):
    """Normalize an image: multiply by a scale, subtract mean, divide by std.

    Args:
        scale: number, or a string expression that is evaluated to a
            number; defaults to 1/255.
        mean: per-channel mean; defaults to [0.485, 0.456, 0.406].
        std: per-channel std; defaults to [0.229, 0.224, 0.225].
        order: 'chw' reshapes mean/std to (3, 1, 1); anything else to
            (1, 1, 3).
        **kwargs: accepted and ignored.
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # NOTE(review): eval() executes arbitrary code; only safe while
            # `scale` comes from trusted configuration.
            scale = eval(scale)
        if scale is None:
            scale = 1.0 / 255.0
        self.scale = np.float32(scale)

        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        broadcast_shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(broadcast_shape).astype('float32')
        self.std = np.array(std).reshape(broadcast_shape).astype('float32')

    def __call__(self, data):
        """Normalize ``data['image']`` (ndarray or PIL image) and store it back."""
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        scaled = img.astype('float32') * self.scale
        data['image'] = (scaled - self.mean) / self.std
        return data
class ToCHWImage(object):
    """Convert an HWC image to CHW layout."""

    def __init__(self, **kwargs):
        # No configuration; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """Transpose ``data['image']`` with axes ``(2, 0, 1)`` and store it back.

        A PIL image is first converted to an ndarray.
        """
        image = data['image']
        from PIL import Image
        if isinstance(image, Image.Image):
            image = np.array(image)
        data['image'] = image.transpose((2, 0, 1))
        return data
class Fasttext(object):
    """Attach a fastText lookup of ``data['label']`` as ``data['fast_label']``."""

    def __init__(self, path="None", **kwargs):
        # Imported lazily so the dependency is needed only when this
        # operator is actually used.
        import fasttext
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        """Index the loaded model with the label and store the result."""
        data['fast_label'] = self.fast_model[data['label']]
        return data
class KeepKeys(object):
    """Project a data dict onto an ordered list of selected values."""

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``data[key]`` for each key in ``keep_keys``, in that order."""
        return [data[key] for key in self.keep_keys]
class Pad(object):
    """Zero-pad an image on the bottom and right.

    Args:
        size: target ``(h, w)``; an int means a square target. When it is
            falsy the image is padded up to the next multiple of
            ``size_div`` instead.
        size_div: divisor used when ``size`` is not given.
        **kwargs: accepted and ignored.

    Raises:
        TypeError: if ``size`` is neither ``None`` nor int/list/tuple.
    """

    def __init__(self, size=None, size_div=32, **kwargs):
        if not (size is None or isinstance(size, (int, list, tuple))):
            raise TypeError("Type of target_size is invalid. Now is {}".format(
                type(size)))
        self.size = [size, size] if isinstance(size, int) else size
        self.size_div = size_div

    def __call__(self, data):
        """Pad ``data['image']`` and store the padded image back."""
        img = data['image']
        img_h, img_w = img.shape[0], img.shape[1]

        if self.size:
            target_h, target_w = self.size
            assert (
                img_h < target_h and img_w < target_w
            ), '(h, w) of target size should be greater than (img_h, img_w)'
        else:
            div = self.size_div
            # Round each side up to a multiple of `div`, at least `div`.
            target_h = max(int(math.ceil(img.shape[0] / div) * div), div)
            target_w = max(int(math.ceil(img.shape[1] / div) * div), div)

        data['image'] = cv2.copyMakeBorder(
            img,
            0,
            target_h - img_h,
            0,
            target_w - img_w,
            cv2.BORDER_CONSTANT,
            value=0)
        return data
class LinearResize(object):
    """Resize an image towards ``target_size``.

    Args:
        target_size (int | list): target size; an int means a square
        keep_ratio (bool): whether to keep the aspect ratio, default true
        interp (int): OpenCV interpolation flag
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        self.target_size = ([target_size, target_size]
                            if isinstance(target_size, int) else target_size)
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image
            im_info (dict): info of image
        Returns:
            im (np.ndarray): resized image
            im_info (dict): same dict with 'im_shape' and 'scale_factor'
                (both float32, ordered y then x) updated
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im.shape[2]  # value unused; raises on input without a third axis
        scale_y, scale_x = self.generate_scale(im)
        im = cv2.resize(
            im,
            None,
            None,
            fx=scale_x,
            fy=scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [scale_y, scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image
        Returns:
            im_scale_y: the resize ratio of Y
            im_scale_x: the resize ratio of X
        """
        origin_shape = im.shape[:2]
        im.shape[2]  # value unused; raises on input without a third axis

        if not self.keep_ratio:
            resize_h, resize_w = self.target_size
            return (resize_h / float(origin_shape[0]),
                    resize_w / float(origin_shape[1]))

        short_side = np.min(origin_shape)
        long_side = np.max(origin_shape)
        target_short = np.min(self.target_size)
        target_long = np.max(self.target_size)

        # Fit the short side first; fall back to the long side when that
        # would push the long side past its limit.
        scale = float(target_short) / float(short_side)
        if np.round(scale * long_side) > target_long:
            scale = float(target_long) / float(long_side)
        return scale, scale
class Resize(object):
    """Resize an image to a fixed ``(h, w)`` and rescale its polygons."""

    def __init__(self, size=(640, 640), **kwargs):
        self.size = size

    def resize_image(self, img):
        """Resize ``img`` to ``self.size``; return it with ``[ratio_h, ratio_w]``."""
        target_h, target_w = self.size
        src_h, src_w = img.shape[:2]  # (h, w, c)
        ratios = [float(target_h) / src_h, float(target_w) / src_w]
        resized = cv2.resize(img, (int(target_w), int(target_h)))
        return resized, ratios

    def __call__(self, data):
        """Resize ``data['image']``; scale ``data['polys']`` when present."""
        img = data['image']
        has_polys = 'polys' in data
        if has_polys:
            text_polys = data['polys']

        img_resize, [ratio_h, ratio_w] = self.resize_image(img)

        if has_polys:
            scaled = [[[pt[0] * ratio_w, pt[1] * ratio_h] for pt in poly]
                      for poly in text_polys]
            data['polys'] = np.array(scaled, dtype=np.float32)

        data['image'] = img_resize
        return data
class DetResizeForTest(object):
    """Resize an image for text detection and record the resize ratios.

    The strategy is selected from the keyword arguments:

    * ``image_shape`` (optionally ``keep_ratio``): resize to a fixed shape
      (``resize_type`` 1).
    * ``limit_side_len`` (optionally ``limit_type``, default ``'min'``):
      limit one side, then round both sides to a multiple of 32
      (``resize_type`` 0).
    * ``resize_long``: scale the longer side to that length, then round
      both sides up to a multiple of 128 (``resize_type`` 2).
    * none of the above: ``limit_side_len=736`` with ``limit_type='min'``.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize ``data['image']`` and set ``data['shape']``.

        ``data['shape']`` is ``[src_h, src_w, ratio_h, ratio_w]`` where the
        source size is taken before any padding.
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        # Very small images are padded to at least 32x32 before resizing.
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)

        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        """Return a uint8 copy of ``im`` padded to at least 32 on each side."""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """Resize to ``self.image_shape``.

        With ``keep_ratio`` the width follows the aspect ratio and is
        rounded up to a multiple of 32.
        """
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            resize_w = math.ceil(resize_w / 32) * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        Resize the image to a size that is a multiple of 32.

        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, [ratio_h, ratio_w]
        raises:
            Exception: if ``self.limit_type`` is not 'max', 'min' or
                'resize_long'. Errors from ``cv2.resize`` propagate.
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        if self.limit_type == 'max':
            # Shrink only when the longer side exceeds the limit.
            longest = max(h, w)
            ratio = float(limit_side_len) / longest \
                if longest > limit_side_len else 1.
        elif self.limit_type == 'min':
            # Enlarge only when the shorter side is below the limit.
            shortest = min(h, w)
            ratio = float(limit_side_len) / shortest \
                if shortest < limit_side_len else 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')

        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        if int(resize_w) <= 0 or int(resize_h) <= 0:
            return None, (None, None)
        try:
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            # Report the offending sizes, then let the caller decide what
            # to do. Exiting the interpreter with status 0 from here would
            # take down the whole host process and report success.
            print(img.shape, resize_w, resize_h)
            raise

        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to ``self.resize_long``.

        Both sides are then rounded up to a multiple of 128.
        """
        h, w, _ = img.shape

        ratio = float(self.resize_long) / max(h, w)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
    """Resize an image for end-to-end inference and record the ratios.

    Required keyword arguments: ``max_side_len`` and ``valid_set``.
    """

    def __init__(self, **kwargs):
        super(E2EResizeForTest, self).__init__()
        self.max_side_len = kwargs['max_side_len']
        self.valid_set = kwargs['valid_set']

    def __call__(self, data):
        """Resize ``data['image']``; set ``data['shape']`` to
        ``[src_h, src_w, ratio_h, ratio_w]``."""
        img = data['image']
        src_h, src_w, _ = img.shape
        if self.valid_set == 'totaltext':
            resizer = self.resize_image_for_totaltext
        else:
            resizer = self.resize_image
        im_resized, (ratio_h, ratio_w) = resizer(
            img, max_side_len=self.max_side_len)
        data['image'] = im_resized
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_for_totaltext(self, im, max_side_len=512):
        """Scale by 1.25, or so the height equals ``max_side_len`` if 1.25
        would exceed it; then round both sides up to a multiple of 128."""
        h, w, _ = im.shape

        ratio = 1.25
        if h * ratio > max_side_len:
            ratio = float(max_side_len) / h

        max_stride = 128
        resize_h = -(-int(h * ratio) // max_stride) * max_stride
        resize_w = -(-int(w * ratio) // max_stride) * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        return im, (resize_h / float(h), resize_w / float(w))

    def resize_image(self, im, max_side_len=512):
        """
        Resize the image to a size that is a multiple of max_stride.

        :param im: the image to resize
        :param max_side_len: length the longer side is scaled to before
            rounding
        :return: the resized image and the (ratio_h, ratio_w) tuple
        """
        h, w, _ = im.shape

        # Fix the longer side
        ratio = float(max_side_len) / max(h, w)

        max_stride = 128
        resize_h = -(-int(h * ratio) // max_stride) * max_stride
        resize_w = -(-int(w * ratio) // max_stride) * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        return im, (resize_h / float(h), resize_w / float(w))
class KieResize(object):
    """Resize an image and its boxes for key-information extraction.

    Required keyword argument: ``img_scale`` - a pair stored as
    ``(max_side, min_side)``.
    """

    def __init__(self, **kwargs):
        super(KieResize, self).__init__()
        img_scale = kwargs['img_scale']
        self.max_side, self.min_side = img_scale[0], img_scale[1]

    def __call__(self, data):
        """Resize image and points.

        The originals are kept under ``'ori_image'`` / ``'ori_boxes'``;
        ``data['shape']`` is the resized ``[h, w]``.
        """
        img = data['image']
        points = data['points']
        src_h, src_w, _ = img.shape
        (im_resized, scale_factor, [ratio_h, ratio_w],
         [new_h, new_w]) = self.resize_image(img)
        resize_points = self.resize_boxes(img, points, scale_factor)
        data['ori_image'] = img
        data['ori_boxes'] = points
        data['points'] = resize_points
        data['image'] = im_resized
        data['shape'] = np.array([new_h, new_w])
        return data

    def resize_image(self, img):
        """Resize ``img`` and paste it into a 1024x1024 float32 canvas.

        Returns:
            ``(canvas, scale_factor, [h_scale, w_scale], [new_h, new_w])``
            where ``scale_factor`` is ``[w, h, w, h]`` as float32.
        """
        canvas = np.zeros([1024, 1024, 3], dtype='float32')
        # NOTE(review): the limits are fixed here; self.max_side and
        # self.min_side are not consulted - confirm this is intended.
        scale = [512, 1024]
        h, w = img.shape[:2]
        long_limit, short_limit = max(scale), min(scale)
        scale_factor = min(long_limit / max(h, w),
                           short_limit / min(h, w))

        resize_w = int(w * float(scale_factor) + 0.5)
        resize_h = int(h * float(scale_factor) + 0.5)

        max_stride = 32
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(img, (resize_w, resize_h))

        new_h, new_w = im.shape[:2]
        w_scale, h_scale = new_w / w, new_h / h
        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        canvas[:new_h, :new_w, :] = im
        return canvas, scale_factor, [h_scale, w_scale], [new_h, new_w]

    def resize_boxes(self, im, points, scale_factor):
        """Scale ``points`` and clip even columns to the width of ``im`` and
        odd columns to its height."""
        scaled = points * scale_factor
        height, width = im.shape[:2]
        scaled[:, 0::2] = np.clip(scaled[:, 0::2], 0, width)
        scaled[:, 1::2] = np.clip(scaled[:, 1::2], 0, height)
        return scaled
class SRResize(object):
    """Resize the low-resolution (and optionally high-resolution) image.

    ``data['image_lr']`` is resized to ``(imgW, imgH)`` divided by
    ``down_sample_scale`` and stored as ``data['img_lr']``. Unless
    ``infer_mode`` is set, ``data['image_hr']`` is resized to
    ``(imgW, imgH)`` and stored as ``data['img_hr']``.
    """

    def __init__(self,
                 imgH=32,
                 imgW=128,
                 down_sample_scale=4,
                 keep_ratio=False,
                 min_ratio=1,
                 mask=False,
                 infer_mode=False,
                 **kwargs):
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Resize the image(s) in ``data`` and return ``data``."""
        factor = self.down_sample_scale
        lr_resizer = ResizeNormalize(
            (self.imgW // factor, self.imgH // factor))
        data["img_lr"] = lr_resizer(data["image_lr"])
        if self.infer_mode:
            return data

        hr_image = data["image_hr"]
        data["label"]  # value unused; a missing label raises KeyError
        hr_resizer = ResizeNormalize((self.imgW, self.imgH))
        data["img_hr"] = hr_resizer(hr_image)
        return data
class ResizeNormalize(object):
    """Resize an image and return it as a CHW float32 array divided by 255."""

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Call ``img.resize(size, interpolation)``, then convert to CHW."""
        resized = img.resize(self.size, self.interpolation)
        as_float = np.array(resized).astype("float32")
        return as_float.transpose((2, 0, 1)) / 255
class GrayImageChannelFormat(object):
    """
    Convert the image to a single gray channel with a leading axis of 1.

    The original image is kept under ``data['src_image']``.
    Args:
        inverse: when true, store ``abs(gray - 1)`` instead of ``gray``
    """

    def __init__(self, inverse=False, **kwargs):
        self.inverse = inverse

    def __call__(self, data):
        """Replace ``data['image']`` with its gray version."""
        source = data['image']
        gray = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)
        gray = np.expand_dims(gray, 0)

        data['image'] = np.abs(gray - 1) if self.inverse else gray
        data['src_image'] = source
        return data
class Permute(object):
    """Permute an image from HWC to CHW."""

    def __init__(self, ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): copy of the image with axes (2, 0, 1)
            im_info (dict): the same ``im_info`` object
        """
        chw = im.transpose((2, 0, 1)).copy()
        return chw, im_info
class PadStride(object):
    """Zero-pad a CHW image so its height and width are multiples of a stride.

    Args:
        stride (int): required divisor of height and width; values <= 0
            disable padding
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image in (C, H, W) layout
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): the input itself when padding is disabled,
                otherwise a new float32 array padded on the bottom/right
            im_info (dict): the same ``im_info`` object
        """
        stride = self.coarsest_stride
        if stride <= 0:
            return im, im_info

        channels, height, width = im.shape
        padded_h = int(np.ceil(float(height) / stride) * stride)
        padded_w = int(np.ceil(float(width) / stride) * stride)

        canvas = np.zeros((channels, padded_h, padded_w), dtype=np.float32)
        canvas[:, :height, :width] = im
        return canvas, im_info
def decode_image(im_file, im_info):
    """Read an RGB image and initialise its info dict.

    Args:
        im_file (str|np.ndarray): image path, or an already loaded array
            which is used as is
        im_info (dict): info of image; updated in place
    Returns:
        im (np.ndarray): the image (converted BGR -> RGB when read from
            a path)
        im_info (dict): ``im_info`` with 'im_shape' set to the image
            (h, w) and 'scale_factor' set to [1., 1.], both float32
    """
    if not isinstance(im_file, str):
        im = im_file
    else:
        with open(im_file, 'rb') as handle:
            raw = handle.read()
        encoded = np.frombuffer(raw, dtype='uint8')
        # imdecode yields BGR; callers expect RGB.
        im = cv2.cvtColor(cv2.imdecode(encoded, 1), cv2.COLOR_BGR2RGB)

    im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
    im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
    return im, im_info
def preprocess(im, preprocess_ops):
    """Decode ``im`` and run it through every operator in order.

    Each operator is called as ``operator(im, im_info)`` and must return
    the updated ``(im, im_info)`` pair.

    Returns:
        The final ``(im, im_info)`` pair.
    """
    initial_info = {
        'scale_factor': np.array(
            [1., 1.], dtype=np.float32),
        'im_shape': None,
    }
    result = decode_image(im, initial_info)
    for op in preprocess_ops:
        result = op(*result)
    im, im_info = result
    return im, im_info
| import copy | |||||
| import numpy as np | |||||
| import cv2 | |||||
| import paddle | |||||
| from shapely.geometry import Polygon | |||||
| import pyclipper | |||||
def build_post_process(config, global_config=None):
    """Instantiate the post-processing class named by ``config['name']``.

    Args:
        config (dict): must contain 'name'; the remaining items are passed
            to the class as keyword arguments. The dict is deep-copied
            first, so the caller's object is not modified.
        global_config (dict | None): extra keyword arguments merged over
            ``config``.

    Returns:
        The constructed post-processor, or ``None`` when the name is the
        string "None".

    Raises:
        AssertionError: if the name is not a supported post-processor.
        NameError: if a supported name is not defined in this module.
    """
    support_dict = ['DBPostProcess', 'CTCLabelDecode']

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if module_name == "None":
        return
    if global_config is not None:
        config.update(global_config)

    # Raised explicitly so the whitelist check is not stripped under
    # `python -O`; the exception type and payload are unchanged.
    if module_name not in support_dict:
        raise AssertionError(Exception(
            'post process only support {}'.format(support_dict)))

    # Look the whitelisted class up by name; no string is evaluated.
    try:
        factory = globals()[module_name]
    except KeyError:
        raise NameError(
            "name '{}' is not defined".format(module_name)) from None
    return factory(**config)
| class DBPostProcess(object): | |||||
| """ | |||||
| The post process for Differentiable Binarization (DB). | |||||
| """ | |||||
| def __init__(self, | |||||
| thresh=0.3, | |||||
| box_thresh=0.7, | |||||
| max_candidates=1000, | |||||
| unclip_ratio=2.0, | |||||
| use_dilation=False, | |||||
| score_mode="fast", | |||||
| box_type='quad', | |||||
| **kwargs): | |||||
| self.thresh = thresh | |||||
| self.box_thresh = box_thresh | |||||
| self.max_candidates = max_candidates | |||||
| self.unclip_ratio = unclip_ratio | |||||
| self.min_size = 3 | |||||
| self.score_mode = score_mode | |||||
| self.box_type = box_type | |||||
| assert score_mode in [ | |||||
| "slow", "fast" | |||||
| ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) | |||||
| self.dilation_kernel = None if not use_dilation else np.array( | |||||
| [[1, 1], [1, 1]]) | |||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
    '''
    Extract text polygons from a binarized map.

    pred: score map passed to ``self.box_score_fast``
    _bitmap: single 2-D map of shape (H, W) whose values are
        binarized as {0, 1}
    dest_width, dest_height: size the polygon coordinates are rescaled
        and clipped to
    return: (boxes, scores) - polygons as nested lists and their scores
    '''
    bitmap = _bitmap
    height, width = bitmap.shape

    boxes = []
    scores = []

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 4.x; the contours are second from last in
    # both, which matches the handling in boxes_from_bitmap.
    outs = cv2.findContours((bitmap * 255).astype(np.uint8),
                            cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    contours = outs[-2]

    for contour in contours[:self.max_candidates]:
        epsilon = 0.002 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        points = approx.reshape((-1, 2))
        if points.shape[0] < 4:
            continue

        score = self.box_score_fast(pred, points.reshape(-1, 2))
        if self.box_thresh > score:
            continue

        # At least four points are guaranteed here, so the polygon can
        # always be expanded; skip it when expansion yields several parts.
        box = self.unclip(points, self.unclip_ratio)
        if len(box) > 1:
            continue
        box = box.reshape(-1, 2)

        _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
        if sside < self.min_size + 2:
            continue

        # Rescale from map coordinates to the destination size.
        box = np.array(box)
        box[:, 0] = np.clip(
            np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(
            np.round(box[:, 1] / height * dest_height), 0, dest_height)
        boxes.append(box.tolist())
        scores.append(score)
    return boxes, scores
| def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): | |||||
| ''' | |||||
| _bitmap: single map with shape (1, H, W), | |||||
| whose values are binarized as {0, 1} | |||||
| ''' | |||||
| bitmap = _bitmap | |||||
| height, width = bitmap.shape | |||||
| outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, | |||||
| cv2.CHAIN_APPROX_SIMPLE) | |||||
| if len(outs) == 3: | |||||
| img, contours, _ = outs[0], outs[1], outs[2] | |||||
| elif len(outs) == 2: | |||||
| contours, _ = outs[0], outs[1] | |||||
| num_contours = min(len(contours), self.max_candidates) | |||||
| boxes = [] | |||||
| scores = [] | |||||
| for index in range(num_contours): | |||||
| contour = contours[index] | |||||
| points, sside = self.get_mini_boxes(contour) | |||||
| if sside < self.min_size: | |||||
| continue | |||||
| points = np.array(points) | |||||
| if self.score_mode == "fast": | |||||
| score = self.box_score_fast(pred, points.reshape(-1, 2)) | |||||
| else: | |||||
| score = self.box_score_slow(pred, contour) | |||||
| if self.box_thresh > score: | |||||
| continue | |||||
| box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2) | |||||
| box, sside = self.get_mini_boxes(box) | |||||
| if sside < self.min_size + 2: | |||||
| continue | |||||
| box = np.array(box) | |||||
| box[:, 0] = np.clip( | |||||
| np.round(box[:, 0] / width * dest_width), 0, dest_width) | |||||
| box[:, 1] = np.clip( | |||||
| np.round(box[:, 1] / height * dest_height), 0, dest_height) | |||||
| boxes.append(box.astype("int32")) | |||||
| scores.append(score) | |||||
| return np.array(boxes, dtype="int32"), scores | |||||
| def unclip(self, box, unclip_ratio): | |||||
| poly = Polygon(box) | |||||
| distance = poly.area * unclip_ratio / poly.length | |||||
| offset = pyclipper.PyclipperOffset() | |||||
| offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | |||||
| expanded = np.array(offset.Execute(distance)) | |||||
| return expanded | |||||
| def get_mini_boxes(self, contour): | |||||
| bounding_box = cv2.minAreaRect(contour) | |||||
| points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) | |||||
| index_1, index_2, index_3, index_4 = 0, 1, 2, 3 | |||||
| if points[1][1] > points[0][1]: | |||||
| index_1 = 0 | |||||
| index_4 = 1 | |||||
| else: | |||||
| index_1 = 1 | |||||
| index_4 = 0 | |||||
| if points[3][1] > points[2][1]: | |||||
| index_2 = 2 | |||||
| index_3 = 3 | |||||
| else: | |||||
| index_2 = 3 | |||||
| index_3 = 2 | |||||
| box = [ | |||||
| points[index_1], points[index_2], points[index_3], points[index_4] | |||||
| ] | |||||
| return box, min(bounding_box[1]) | |||||
| def box_score_fast(self, bitmap, _box): | |||||
| ''' | |||||
| box_score_fast: use bbox mean score as the mean score | |||||
| ''' | |||||
| h, w = bitmap.shape[:2] | |||||
| box = _box.copy() | |||||
| xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) | |||||
| xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) | |||||
| ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) | |||||
| ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) | |||||
| mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) | |||||
| box[:, 0] = box[:, 0] - xmin | |||||
| box[:, 1] = box[:, 1] - ymin | |||||
| cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) | |||||
| return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] | |||||
| def box_score_slow(self, bitmap, contour): | |||||
| ''' | |||||
| box_score_slow: use polyon mean score as the mean score | |||||
| ''' | |||||
| h, w = bitmap.shape[:2] | |||||
| contour = contour.copy() | |||||
| contour = np.reshape(contour, (-1, 2)) | |||||
| xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) | |||||
| xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) | |||||
| ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) | |||||
| ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) | |||||
| mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) | |||||
| contour[:, 0] = contour[:, 0] - xmin | |||||
| contour[:, 1] = contour[:, 1] - ymin | |||||
| cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) | |||||
| return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] | |||||
| def __call__(self, outs_dict, shape_list): | |||||
| pred = outs_dict['maps'] | |||||
| if isinstance(pred, paddle.Tensor): | |||||
| pred = pred.numpy() | |||||
| pred = pred[:, 0, :, :] | |||||
| segmentation = pred > self.thresh | |||||
| boxes_batch = [] | |||||
| for batch_index in range(pred.shape[0]): | |||||
| src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] | |||||
| if self.dilation_kernel is not None: | |||||
| mask = cv2.dilate( | |||||
| np.array(segmentation[batch_index]).astype(np.uint8), | |||||
| self.dilation_kernel) | |||||
| else: | |||||
| mask = segmentation[batch_index] | |||||
| if self.box_type == 'poly': | |||||
| boxes, scores = self.polygons_from_bitmap(pred[batch_index], | |||||
| mask, src_w, src_h) | |||||
| elif self.box_type == 'quad': | |||||
| boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, | |||||
| src_w, src_h) | |||||
| else: | |||||
| raise ValueError( | |||||
| "box_type can only be one of ['quad', 'poly']") | |||||
| boxes_batch.append({'points': boxes}) | |||||
| return boxes_batch | |||||
class BaseRecLabelDecode(object):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False):
        """Build the character table.

        Without ``character_dict_path`` the table is the digits followed by
        the lowercase ASCII letters. Otherwise every line of the file is
        one entry, a space is appended when ``use_space_char`` is true, and
        reversed decoding is enabled when the path contains 'arabic'.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is not None:
            with open(character_dict_path, "rb") as dict_file:
                for raw_line in dict_file.readlines():
                    entry = raw_line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(entry)
            if use_space_char:
                self.character_str.append(" ")
            charset = list(self.character_str)
            if 'arabic' in character_dict_path:
                self.reverse = True
        else:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            charset = list(self.character_str)

        charset = self.add_special_char(charset)
        # character -> index lookup; `character` is the inverse table.
        self.dict = {symbol: position for position, symbol in enumerate(charset)}
        self.character = charset

    def pred_reverse(self, pred):
        """Reverse the order of ``pred`` while keeping runs of characters
        matching ``[a-zA-Z0-9 :*./%+-]`` in their original internal order."""
        segments = []
        run = ''
        for ch in pred:
            if re.search('[a-zA-Z0-9 :*./%+-]', ch):
                run += ch
                continue
            if run != '':
                segments.append(run)
            segments.append(ch)
            run = ''
        if run != '':
            segments.append(run)

        segments.reverse()
        return ''.join(segments)

    def add_special_char(self, dict_character):
        """Hook for subclasses to extend the table; identity here."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        decoded = []
        skip_ids = self.get_ignored_tokens()
        for row in range(len(text_index)):
            indices = text_index[row]
            keep = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                # Drop an index that repeats its immediate predecessor.
                keep[1:] = indices[1:] != indices[:-1]
            for skip_id in skip_ids:
                keep &= indices != skip_id

            chars = [self.character[text_id] for text_id in indices[keep]]
            if text_prob is None:
                confidences = [1] * len(keep)
            else:
                confidences = text_prob[row][keep]
            if len(confidences) == 0:
                confidences = [0]

            text = ''.join(chars)
            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            decoded.append((text, np.mean(confidences).tolist()))
        return decoded

    def get_ignored_tokens(self):
        """Indices that never produce a character."""
        return [0]  # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class; extra keyword
        arguments are accepted and ignored."""
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode CTC predictions.

        Args:
            preds: 3-D score array (arg-max and max are taken over axis 2),
                a tensor-like object exposing ``.numpy()``, or a
                tuple/list whose last element is one of those.
            label: optional index array; decoded as well when given.

        Returns:
            The decoded ``(text, confidence)`` list, or a
            ``(text, label)`` pair when ``label`` is provided.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        # Duck-type the tensor conversion so the check does not depend on
        # a particular framework being importable; ndarrays pass through.
        if hasattr(preds, "numpy"):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        """Prepend the 'blank' entry so that index 0 is the CTC blank."""
        dict_character = ['blank'] + dict_character
        return dict_character
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import os | |||||
| import onnxruntime as ort | |||||
| from huggingface_hub import snapshot_download | |||||
| from .operators import * | |||||
| from rag.settings import cron_logger | |||||
class Recognizer(object):
    """Runs an ONNX detection model and maps its output rows to labels."""

    def __init__(self, label_list, task_name, model_dir=None):
        """
        Load ``<model_dir>/<task_name>.onnx`` into an ONNX Runtime session.

        When ``model_dir`` is not given, the "InfiniFlow/ocr" repository is
        fetched with ``snapshot_download``. If downloading from HuggingFace
        is a problem, pointing the ``HF_ENDPOINT`` environment variable at
        a mirror may help, e.g. on Linux:
            export HF_ENDPOINT=https://hf-mirror.com

        Args:
            label_list: class names indexed by the model's class id.
            task_name: model file name without the ".onnx" suffix.
            model_dir: directory holding the model file (optional).

        Raises:
            ValueError: if the model file does not exist.
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))

        if ort.get_device() == "GPU":
            providers = ['CUDAExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.ort_sess = ort.InferenceSession(model_file_path, providers=providers)
        self.label_list = label_list

    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray), each indexed
                as (channels, height, width)
            im_info (list(dict)): list of image info with 'im_shape' and
                'scale_factor' entries
        Returns:
            inputs (dict): input of model, with keys 'image', 'im_shape'
                and 'scale_factor'. Several images are zero-padded at the
                bottom/right to a common height and width.
        """
        inputs = {}

        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        im_shape = []
        scale_factor = []
        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))

        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        max_shape_h = max([e.shape[1] for e in imgs])
        max_shape_w = max([e.shape[2] for e in imgs])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs

    def preprocess(self, image_list):
        """Run the fixed resize / standardize / permute / pad pipeline on
        every image and return one model-input dict per image.

        Each dict carries 'image' and 'scale_factor' as float32 arrays with
        a leading batch axis of 1.
        """
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # NOTE: eval() only ever sees the hard-coded operator names
            # listed above, never external input.
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            # `preprocess` here is the module-level function, not this method.
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        """Detect objects in every image.

        Args:
            image_list: images; anything that is not an ndarray is
                converted with ``np.array``.
            thr: detections scoring below this value are discarded.
            batch_size: number of images preprocessed together.

        Returns:
            One list per image of ``{"type", "bbox", "score"}`` dicts.
        """
        res = []
        imgs = [im if isinstance(im, np.ndarray) else np.array(im)
                for im in image_list]

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                for b in self.ort_sess.run(None, ins)[0]:
                    # Row layout: [class id, score, box coordinates...].
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    # Reject negative ids too: they would otherwise index
                    # label_list from the end and yield a wrong label.
                    if not 0 <= clsid < len(self.label_list):
                        cron_logger.warning("bad category id")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        #seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import os | |||||
| import PIL | |||||
| from PIL import ImageDraw | |||||
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
    """Draw the detections on every image and save it as ``<index>.jpg``.

    Args:
        image_list: images accepted by ``draw_box``.
        results: per-image detection lists, aligned with ``image_list``.
        labels: class names, used to pick a colour per class.
        output_dir: destination directory, created when missing.
        threshold: detections scoring below this value are not drawn.
    """
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(output_dir, exist_ok=True)

    for idx, im in enumerate(image_list):
        im = draw_box(im, results[idx], labels, threshold=threshold)

        out_path = os.path.join(output_dir, f"{idx}.jpg")
        im.save(out_path, quality=95)
        print("save result to: " + out_path)
def draw_box(im, result, lables, threshold=0.5):
    """Draw every detection scoring at least ``threshold`` onto ``im``.

    Each detection gets a rectangular outline in its class colour and a
    filled caption "<type> <score>" placed just above the box. The image
    is modified in place and returned.
    """
    line_width = min(im.size) // 320
    canvas = ImageDraw.Draw(im)
    palette = get_color_map_list(len(lables))
    color_by_type = {}
    for position, label in enumerate(lables):
        color_by_type[label.lower()] = palette[position]

    visible = list(filter(lambda det: det["score"] >= threshold, result))

    for det in visible:
        rgb = tuple(color_by_type[det["type"]])
        xmin, ymin, xmax, ymax = det["bbox"]
        outline = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                   (xmin, ymin)]
        canvas.line(outline, width=line_width, fill=rgb)

        # draw label
        caption = "{} {:.4f}".format(det["type"], det["score"])
        tw, th = imagedraw_textsize_c(canvas, caption)
        top_left = (xmin + 1, ymin - th)
        canvas.rectangle([top_left, (xmin + tw + 1, ymin)], fill=rgb)
        canvas.text(top_left, caption, fill=(255, 255, 255))
    return im
def get_color_map_list(num_classes):
    """
    Build one colour per class by spreading the bits of the class index
    over the three channels, starting from the most significant bit.

    Args:
        num_classes (int): number of class
    Returns:
        color_map (list): RGB color list
    """
    palette = []
    for class_id in range(num_classes):
        rgb = [0, 0, 0]
        remaining = class_id
        shift = 7
        while remaining:
            # The three lowest bits go to the three channels in order.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.append(rgb)
    return palette
def imagedraw_textsize_c(draw, text):
    """Return ``(width, height)`` of ``text`` as rendered by ``draw``.

    Uses ``draw.textbbox`` when the PIL major version is 10 or higher and
    ``draw.textsize`` otherwise.
    """
    major_version = int(PIL.__version__.split('.')[0])
    if major_version >= 10:
        left, top, right, bottom = draw.textbbox((0, 0), text)
        return right - left, bottom - top

    tw, th = draw.textsize(text)
    return tw, th
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import random | |||||
| import re | import re | ||||
| import numpy as np | |||||
| from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \ | |||||
| from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, \ | |||||
| hierarchical_merge, make_colon_as_title, naive_merge, random_choices | hierarchical_merge, make_colon_as_title, naive_merge, random_choices | ||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.docx_parser import HuDocxParser | |||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser, DocxParser | |||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __call__(self, filename, binary=None, from_page=0, | def __call__(self, filename, binary=None, from_page=0, | ||||
| to_page=100000, zoomin=3, callback=None): | to_page=100000, zoomin=3, callback=None): | ||||
| self.__images__( | self.__images__( | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| start = timer() | start = timer() | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| callback(0.47, "Layout analysis finished") | callback(0.47, "Layout analysis finished") | ||||
| print("paddle layouts:", timer() - start) | print("paddle layouts:", timer() - start) | ||||
| self._table_transformer_job(zoomin) | self._table_transformer_job(zoomin) | ||||
| sections,tbls = [], [] | sections,tbls = [], [] | ||||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | if re.search(r"\.docx?$", filename, re.IGNORECASE): | ||||
| callback(0.1, "Start to parse.") | callback(0.1, "Start to parse.") | ||||
| doc_parser = HuDocxParser() | |||||
| doc_parser = DocxParser() | |||||
| # TODO: table of contents need to be removed | # TODO: table of contents need to be removed | ||||
| sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) | sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) | ||||
| remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) | remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from docx import Document | from docx import Document | ||||
| from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \ | |||||
| from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \ | |||||
| make_colon_as_title | make_colon_as_title | ||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.docx_parser import HuDocxParser | |||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser, DocxParser | |||||
| from rag.settings import cron_logger | from rag.settings import cron_logger | ||||
| class Docx(HuDocxParser): | |||||
| class Docx(DocxParser): | |||||
| def __init__(self): | def __init__(self): | ||||
| pass | pass | ||||
| return [l for l in lines if l] | return [l for l in lines if l] | ||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __call__(self, filename, binary=None, from_page=0, | def __call__(self, filename, binary=None, from_page=0, | ||||
| to_page=100000, zoomin=3, callback=None): | to_page=100000, zoomin=3, callback=None): | ||||
| self.__images__( | self.__images__( | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| start = timer() | start = timer() | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| callback(0.77, "Layout analysis finished") | callback(0.77, "Layout analysis finished") | ||||
| cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) | cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) | ||||
| self._naive_vertical_merge() | self._naive_vertical_merge() |
| import copy | import copy | ||||
| import re | import re | ||||
| from rag.parser import tokenize | |||||
| from deepdoc.parser import tokenize | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser | |||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __call__(self, filename, binary=None, from_page=0, | def __call__(self, filename, binary=None, from_page=0, | ||||
| to_page=100000, zoomin=3, callback=None): | to_page=100000, zoomin=3, callback=None): | ||||
| self.__images__( | self.__images__( | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| start = timer() | start = timer() | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| callback(0.5, "Layout analysis finished.") | callback(0.5, "Layout analysis finished.") | ||||
| print("paddle layouts:", timer() - start) | print("paddle layouts:", timer() - start) | ||||
| self._table_transformer_job(zoomin) | self._table_transformer_job(zoomin) |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| from rag.app import laws | from rag.app import laws | ||||
| from rag.parser import is_english, tokenize, naive_merge | |||||
| from deepdoc.parser import is_english, tokenize, naive_merge | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser | |||||
| from rag.settings import cron_logger | from rag.settings import cron_logger | ||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __call__(self, filename, binary=None, from_page=0, | def __call__(self, filename, binary=None, from_page=0, | ||||
| to_page=100000, zoomin=3, callback=None): | to_page=100000, zoomin=3, callback=None): | ||||
| self.__images__( | self.__images__( | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| start = timer() | start = timer() | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| callback(0.77, "Layout analysis finished") | callback(0.77, "Layout analysis finished") | ||||
| cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) | cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) | ||||
| self._naive_vertical_merge() | self._naive_vertical_merge() |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| from collections import Counter | from collections import Counter | ||||
| from api.db import ParserType | from api.db import ParserType | ||||
| from rag.parser import tokenize | |||||
| from deepdoc.parser import tokenize | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser | |||||
| import numpy as np | import numpy as np | ||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __init__(self): | def __init__(self): | ||||
| self.model_speciess = ParserType.PAPER.value | self.model_speciess = ParserType.PAPER.value | ||||
| super().__init__() | super().__init__() | ||||
| from timeit import default_timer as timer | from timeit import default_timer as timer | ||||
| start = timer() | start = timer() | ||||
| self._layouts_paddle(zoomin) | |||||
| self._layouts_rec(zoomin) | |||||
| callback(0.47, "Layout analysis finished") | callback(0.47, "Layout analysis finished") | ||||
| print("paddle layouts:", timer() - start) | print("paddle layouts:", timer() - start) | ||||
| self._table_transformer_job(zoomin) | self._table_transformer_job(zoomin) |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from pptx import Presentation | from pptx import Presentation | ||||
| from rag.parser import tokenize, is_english | |||||
| from deepdoc.parser import tokenize, is_english | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import PdfParser | |||||
| class Ppt(object): | class Ppt(object): | ||||
| return [(txts[i], imgs[i]) for i in range(len(txts))] | return [(txts[i], imgs[i]) for i in range(len(txts))] | ||||
| class Pdf(HuParser): | |||||
| class Pdf(PdfParser): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) | assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) | ||||
| res = [] | res = [] | ||||
| #################### More precisely ################### | #################### More precisely ################### | ||||
| # self._layouts_paddle(zoomin) | |||||
| # self._layouts_rec(zoomin) | |||||
| # self._text_merge() | # self._text_merge() | ||||
| # pages = {} | # pages = {} | ||||
| # for b in self.boxes: | # for b in self.boxes: |
| import random | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import re | import re | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from nltk import word_tokenize | from nltk import word_tokenize | ||||
| from openpyxl import load_workbook | from openpyxl import load_workbook | ||||
| from rag.parser import is_english, random_choices | |||||
| from deepdoc.parser import is_english, random_choices | |||||
| from rag.nlp import huqie, stemmer | from rag.nlp import huqie, stemmer | ||||
| from deepdoc.parser import ExcelParser | |||||
| class Excel(object): | |||||
| class Excel(ExcelParser): | |||||
| def __call__(self, fnm, binary=None, callback=None): | def __call__(self, fnm, binary=None, callback=None): | ||||
| if not binary: | if not binary: | ||||
| wb = load_workbook(fnm) | wb = load_workbook(fnm) |
| import copy | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import base64 | |||||
| import datetime | |||||
| import json | import json | ||||
| import os | |||||
| import re | import re | ||||
| import pandas as pd | |||||
| import requests | import requests | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.settings import stat_logger | |||||
| from rag.nlp import huqie | from rag.nlp import huqie | ||||
| from deepdoc.parser.resume import refactor | |||||
| from deepdoc.parser.resume import step_one, step_two | |||||
| from rag.settings import cron_logger | from rag.settings import cron_logger | ||||
| from rag.utils import rmSpace | from rag.utils import rmSpace | ||||
| forbidden_select_fields4resume = [ | forbidden_select_fields4resume = [ | ||||
| "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" | "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" | ||||
| ] | ] | ||||
| def remote_call(filename, binary): | |||||
| q = { | |||||
| "header": { | |||||
| "uid": 1, | |||||
| "user": "kevinhu", | |||||
| "log_id": filename | |||||
| }, | |||||
| "request": { | |||||
| "p": { | |||||
| "request_id": "1", | |||||
| "encrypt_type": "base64", | |||||
| "filename": filename, | |||||
| "langtype": '', | |||||
| "fileori": base64.b64encode(binary.stream.read()).decode('utf-8') | |||||
| }, | |||||
| "c": "resume_parse_module", | |||||
| "m": "resume_parse" | |||||
| } | |||||
| } | |||||
| for _ in range(3): | |||||
| try: | |||||
| resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q)) | |||||
| resume = resume.json()["response"]["results"] | |||||
| resume = refactor(resume) | |||||
| for k in ["education", "work", "project", "training", "skill", "certificate", "language"]: | |||||
| if not resume.get(k) and k in resume: del resume[k] | |||||
| resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", | |||||
| "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) | |||||
| resume = step_two.parse(resume) | |||||
| return resume | |||||
| except Exception as e: | |||||
| cron_logger.error("Resume parser error: "+str(e)) | |||||
| return {} | |||||
| def chunk(filename, binary=None, callback=None, **kwargs): | def chunk(filename, binary=None, callback=None, **kwargs): | ||||
| """ | """ | ||||
| The supported file formats are pdf, docx and txt. | The supported file formats are pdf, docx and txt. | ||||
| To maximize the effectiveness, parse the resume correctly, | |||||
| please visit https://github.com/infiniflow/ragflow, and sign in the our demo web-site | |||||
| to get token. It's FREE! | |||||
| Set INFINIFLOW_SERVER and INFINIFLOW_TOKEN in '.env' file or | |||||
| using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN in docker container. | |||||
| To maximize the effectiveness, parse the resume correctly, please concat us: https://github.com/infiniflow/ragflow | |||||
| """ | """ | ||||
| if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): | if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): | ||||
| raise NotImplementedError("file type not supported yet(pdf supported)") | raise NotImplementedError("file type not supported yet(pdf supported)") | ||||
| url = os.environ.get("INFINIFLOW_SERVER") | |||||
| token = os.environ.get("INFINIFLOW_TOKEN") | |||||
| if not url or not token: | |||||
| stat_logger.warning( | |||||
| "INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.") | |||||
| return [] | |||||
| if not binary: | if not binary: | ||||
| with open(filename, "rb") as f: | with open(filename, "rb") as f: | ||||
| binary = f.read() | binary = f.read() | ||||
| def remote_call(): | |||||
| nonlocal filename, binary | |||||
| for _ in range(3): | |||||
| try: | |||||
| res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)], | |||||
| headers={"Authorization": token}, timeout=180) | |||||
| res = res.json() | |||||
| if res["retcode"] != 0: | |||||
| raise RuntimeError(res["retmsg"]) | |||||
| return res["data"] | |||||
| except RuntimeError as e: | |||||
| raise e | |||||
| except Exception as e: | |||||
| cron_logger.error("resume parsing:" + str(e)) | |||||
| callback(0.2, "Resume parsing is going on...") | callback(0.2, "Resume parsing is going on...") | ||||
| resume = remote_call() | |||||
| resume = remote_call(filename, binary) | |||||
| if len(resume.keys()) < 7: | if len(resume.keys()) < 7: | ||||
| callback(-1, "Resume is not successfully parsed.") | callback(-1, "Resume is not successfully parsed.") | ||||
| return [] | return [] |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from dateutil.parser import parse as datetime_parse | from dateutil.parser import parse as datetime_parse | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from rag.parser import is_english, tokenize | |||||
| from rag.nlp import huqie, stemmer | |||||
| from deepdoc.parser import is_english, tokenize | |||||
| from rag.nlp import huqie | |||||
| from deepdoc.parser import ExcelParser | |||||
| class Excel(object): | |||||
| class Excel(ExcelParser): | |||||
| def __call__(self, fnm, binary=None, callback=None): | def __call__(self, fnm, binary=None, callback=None): | ||||
| if not binary: | if not binary: | ||||
| wb = load_workbook(fnm) | wb = load_workbook(fnm) |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import re | import re | ||||
| import os | import os | ||||
| import copy | import copy | ||||
| import sys | import sys | ||||
| sys.path.append(os.path.dirname(__file__) + "/../") | sys.path.append(os.path.dirname(__file__) + "/../") | ||||
| if sys.argv[1].split(".")[-1].lower() == "pdf": | if sys.argv[1].split(".")[-1].lower() == "pdf": | ||||
| from parser import PdfParser | |||||
| from deepdoc.parser import PdfParser | |||||
| ckr = PdfChunker(PdfParser()) | ckr = PdfChunker(PdfParser()) | ||||
| if sys.argv[1].split(".")[-1].lower().find("doc") >= 0: | if sys.argv[1].split(".")[-1].lower().find("doc") >= 0: | ||||
| from parser import DocxParser | |||||
| from deepdoc.parser import DocxParser | |||||
| ckr = DocxChunker(DocxParser()) | ckr = DocxChunker(DocxParser()) | ||||
| if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0: | if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0: | ||||
| from parser import ExcelParser | |||||
| from deepdoc.parser import ExcelParser | |||||
| ckr = ExcelChunker(ExcelParser()) | ckr = ExcelChunker(ExcelParser()) | ||||
| # ckr.html(sys.argv[1]) | # ckr.html(sys.argv[1]) |
| from api.db.db_models import Task | from api.db.db_models import Task | ||||
| from api.db.db_utils import bulk_insert_into_db | from api.db.db_utils import bulk_insert_into_db | ||||
| from api.db.services.task_service import TaskService | from api.db.services.task_service import TaskService | ||||
| from rag.parser.pdf_parser import HuParser | |||||
| from deepdoc.parser import HuParser | |||||
| from rag.settings import cron_logger | from rag.settings import cron_logger | ||||
| from rag.utils import MINIO | from rag.utils import MINIO | ||||
| from rag.utils import findMaxTm | from rag.utils import findMaxTm |