* rename vision, add layout and TSR recognizer * trivial fixes (tag: v0.1.0)
@@ -34,7 +34,6 @@ from rag.utils import num_tokens_from_string, encoder, rmSpace

@manager.route('/set', methods=['POST'])
@login_required
@validate_request("dialog_id")
def set_conversation():
    req = request.json
    conv_id = req.get("conversation_id")

@@ -145,7 +144,7 @@ def message_fit_in(msg, max_length=4000):

@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
@validate_request("conversation_id", "messages")
def completion():
    req = request.json
    msg = []

@@ -154,12 +153,20 @@ def completion():
        if m["role"] == "assistant" and not msg: continue
        msg.append({"role": m["role"], "content": m["content"]})
    try:
        e, dia = DialogService.get_by_id(req["dialog_id"])
        e, conv = ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(retmsg="Conversation not found!")
        conv.message.append(msg[-1])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        del req["dialog_id"]
        del req["conversation_id"]
        del req["messages"]
        return get_json_result(data=chat(dia, msg, **req))
        ans = chat(dia, msg, **req)
        conv.reference.append(ans["reference"])
        conv.message.append({"role": "assistant", "content": ans["answer"]})
        ConversationService.update_by_id(conv.id, conv.to_dict())
        return get_json_result(data=ans)
    except Exception as e:
        return server_error_response(e)
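For orientation, a minimal client-side sketch of the reworked /completion contract: the endpoint now takes a conversation_id instead of a dialog_id, persists both the user turn and the assistant turn on the Conversation row, and returns the answer together with its retrieval reference. The base URL, URL prefix, auth cookie, and response envelope are assumptions; only the payload keys come from the hunk above.

    import requests

    # Hypothetical host and session cookie; the payload shape follows the diff.
    BASE = "http://localhost:9380/v1/conversation"
    payload = {
        "conversation_id": "b1946ac92492d2347c6235b4d2611184",
        "messages": [{"role": "user", "content": "What does the contract say about renewal?"}],
    }
    resp = requests.post(f"{BASE}/completion", json=payload, cookies={"session": "..."}).json()
    print(resp["data"]["answer"])     # assistant reply (envelope with a "data" key is assumed)
    print(resp["data"]["reference"])  # retrieved chunks backing the citations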
@@ -194,8 +201,8 @@ def chat(dialog, messages, **kwargs):
                                        dialog.vector_similarity_weight, top=1024, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    if not knowledges and prompt_config["empty_response"]:
        return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
    if not knowledges and prompt_config.get("empty_response"):
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
    kwargs["knowledge"] = "\n".join(knowledges)
    gen_conf = dialog.llm_setting

@@ -205,7 +212,8 @@ def chat(dialog, messages, **kwargs):
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
    answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
    answer = retrievaler.insert_citations(answer,
    if knowledges:
        answer = retrievaler.insert_citations(answer,
                                              [ck["content_ltks"] for ck in kbinfos["chunks"]],
                                              [ck["vector"] for ck in kbinfos["chunks"]],
                                              embd_mdl,

@@ -213,7 +221,7 @@ def chat(dialog, messages, **kwargs):
                                              vtweight=dialog.vector_similarity_weight)
    for c in kbinfos["chunks"]:
        if c.get("vector"): del c["vector"]
    return {"answer": answer, "retrieval": kbinfos}
    return {"answer": answer, "reference": kbinfos}

def use_sql(question, field_map, tenant_id, chat_mdl):
@@ -94,11 +94,11 @@ def list():
    model_type = request.args.get("model_type")
    try:
        objs = TenantLLMService.query(tenant_id=current_user.id)
        mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
        llms = LLMService.get_all()
        llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
        for m in llms:
            m["available"] = m["llm_name"] in mdlnms
            m["available"] = m["fid"] in facts
        res = {}
        for m in llms:
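A toy illustration of what this change does (sample rows invented): availability used to require an API key configured for that exact model name; now any model whose factory (fid) has a configured key counts as available.

    # Invented sample rows mirroring TenantLLM / LLM records.
    tenant_llms = [{"llm_name": "gpt-3.5-turbo", "llm_factory": "OpenAI", "api_key": "sk-..."}]
    llms = [{"llm_name": "gpt-4", "fid": "OpenAI"}, {"llm_name": "qwen-turbo", "fid": "Tongyi-Qianwen"}]

    mdlnms = {o["llm_name"] for o in tenant_llms if o["api_key"]}    # old criterion
    facts = {o["llm_factory"] for o in tenant_llms if o["api_key"]}  # new criterion
    for m in llms:
        old = m["llm_name"] in mdlnms  # gpt-4 -> False
        new = m["fid"] in facts        # gpt-4 -> True, qwen-turbo -> False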
@@ -500,7 +500,7 @@ class Document(DataBaseModel):
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)
    progress = FloatField(default=0)
    progress_msg = CharField(max_length=512, null=True, help_text="process message", default="")
    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True)
    process_duation = FloatField(default=0)
    run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")

@@ -518,7 +518,7 @@ class Task(DataBaseModel):
    begin_at = DateTimeField(null=True)
    process_duation = FloatField(default=0)
    progress = FloatField(default=0)
    progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")

class Dialog(DataBaseModel):

@@ -561,6 +561,7 @@ class Conversation(DataBaseModel):
    dialog_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="converastion name")
    message = JSONField(null=True)
    reference = JSONField(null=True, default=[])

    class Meta:
        db_table = "conversation"
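One hedged aside on the new column: in peewee, a plain default=[] is a single shared list object, whereas a callable default produces a fresh list per row. If the field were ever read before being assigned, the usual pattern is a callable; a sketch of that variant (not necessarily needed here, since completion() above assigns it before appending):

    # Assumed variant with a callable default; JSONField is the project's own field type.
    reference = JSONField(null=True, default=list)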
@@ -75,7 +75,7 @@ class TenantLLMService(CommonService):
        model_config = cls.get_api_key(tenant_id, mdlnm)
        if not model_config:
            raise LookupError("Model({}) not found".format(mdlnm))
            raise LookupError("Model({}) not authorized".format(mdlnm))
        model_config = model_config.to_dict()
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
@@ -0,0 +1,4 @@
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
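Taken together, the new package exposes four recognizers. A minimal usage sketch, assuming the package is importable as deepdoc.vision and that PIL images are accepted (the np.array() conversion in Recognizer.__call__ below suggests they are):

    from PIL import Image
    from deepdoc.vision import OCR, LayoutRecognizer, TableStructureRecognizer  # assumed package path

    pages = [Image.open("page-0.png").convert("RGB")]
    ocr = OCR()                          # downloads InfiniFlow/deepdoc weights on first use
    layout = LayoutRecognizer("layout")  # task/model name "layout" is an assumption
    tsr = TableStructureRecognizer()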
@@ -0,0 +1,119 @@
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

from api.utils.file_utils import get_project_base_directory
from .recognizer import Recognizer


class LayoutRecognizer(Recognizer):
    def __init__(self, domain):
        self.layout_labels = [
            "_background_",
            "Text",
            "Title",
            "Figure",
            "Figure caption",
            "Table",
            "Table caption",
            "Header",
            "Footer",
            "Reference",
            "Equation",
        ]
        super().__init__(self.layout_labels, domain,
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
        def __is_garbage(b):
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        garbages = {}
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            lts = self.sort_Y_firstly(lts, np.mean([l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Tag layout type, layouts are ready
            def findLayout(ty):
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue
                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belongs to nothing
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    if lts_[ii]["type"] in ["footer", "header", "reference"]:
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue
                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"]
                    i += 1

            for lt in ["footer", "header", "reference", "figure caption",
                       "table caption", "title", "text", "table", "figure", "equation"]:
                findLayout(lt)

            # add a synthetic box for figure layouts that have no text box
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] == "figure"]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
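To make the data flow concrete, a hedged sketch of driving the layout recognizer (package path and constructor argument are assumptions; the box keys follow the code above, values invented):

    import numpy as np
    from deepdoc.vision import LayoutRecognizer  # assumed package path

    layout = LayoutRecognizer("layout")  # task name "layout" is an assumption
    page = np.zeros((2400, 1800, 3), dtype=np.uint8)  # one rendered page at 3x scale
    ocr_res = [[  # one OCR box list per page image
        {"text": "1. Introduction", "x0": 60, "x1": 250, "top": 40, "bottom": 62},
        {"text": "3 / 12", "x0": 300, "x1": 340, "top": 780, "bottom": 790},
    ]]
    boxes, page_layout = layout([page], ocr_res, scale_factor=3)
    # "3 / 12" matches the page-number garbage pattern and is dropped; the heading
    # comes back tagged with layout_type/layoutno when a Title region overlaps it.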
@@ -74,7 +74,7 @@ class TextRecognizer(object):
        self.rec_batch_num = 16
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
            "character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
            "use_space_char": True
        }
        self.postprocess_op = build_post_process(postprocess_params)

@@ -450,7 +450,7 @@ class OCR(object):
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

        self.text_detector = TextDetector(model_dir)
        self.text_recognizer = TextRecognizer(model_dir)
@@ -0,0 +1,327 @@
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
from copy import deepcopy

import onnxruntime as ort
from huggingface_hub import snapshot_download

from . import seeit
from .operators import *
from rag.settings import cron_logger


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        if ort.get_device() == "GPU":
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.label_list = label_list

    @staticmethod
    def sort_Y_firstly(arr, threashold):
        # sort by top (y) first, then by x0
        arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
                        and arr[j + 1]["x0"] < arr[j]["x0"]:
                    tmp = deepcopy(arr[j])
                    arr[j] = deepcopy(arr[j + 1])
                    arr[j + 1] = deepcopy(tmp)
        return arr
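The threshold is what makes this a reading-order sort rather than a strict coordinate sort: boxes whose tops differ by less than the threshold are treated as the same visual line and reordered left to right. A small sketch (import path is an assumption; sort_Y_firstly is a staticmethod, so no model load is needed):

    from deepdoc.vision import Recognizer  # assumed package path

    boxes = [
        {"top": 100, "x0": 300, "text": "right cell"},
        {"top": 103, "x0": 50, "text": "left cell"},  # 3px lower on the page, but further left
        {"top": 200, "x0": 50, "text": "next line"},
    ]
    print([b["text"] for b in Recognizer.sort_Y_firstly(boxes, 10)])
    # ['left cell', 'right cell', 'next line'] — tops within 10px count as one visual line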
    @staticmethod
    def sort_X_firstly(arr, threashold, copy=True):
        # sort by x0 first, then by top
        arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
                        and arr[j + 1]["top"] < arr[j]["top"]:
                    tmp = deepcopy(arr[j]) if copy else arr[j]
                    arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
                    arr[j + 1] = deepcopy(tmp) if copy else tmp
        return arr

    @staticmethod
    def sort_C_firstly(arr, thr=0):
        # sort by column (C) first, then by top
        # sorted(arr, key=lambda r: (r["x0"], r["top"]))
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] \
                        or (
                        arr[j + 1]["C"] == arr[j]["C"]
                        and arr[j + 1]["top"] < arr[j]["top"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr

        return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))

    @staticmethod
    def sort_R_firstly(arr, thr=0):
        # sort by row (R) first, then by x0
        # sorted(arr, key=lambda r: (r["top"], r["x0"]))
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] \
                        or (
                        arr[j + 1]["R"] == arr[j]["R"]
                        and arr[j + 1]["x0"] < arr[j]["x0"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr

    @staticmethod
    def overlapped_area(a, b, ratio=True):
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
            tp, btm, x0, x1, b)
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
            x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov
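A worked example of the ratio: the intersection is normalized by the area of the first argument, so the measure is directional. With a spanning (0..10, 0..10) and b spanning (5..15, 5..15), the intersection is 5×5 = 25 against a's area of 100:

    a = {"x0": 0, "x1": 10, "top": 0, "bottom": 10}
    b = {"x0": 5, "x1": 15, "top": 5, "bottom": 15}
    Recognizer.overlapped_area(a, b)               # 25 / 100 = 0.25
    Recognizer.overlapped_area(a, b, ratio=False)  # 25.0, the raw intersection area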
    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        def notOverlapped(a, b):
            return any([a["x1"] < b["x0"],
                        a["x0"] > b["x1"],
                        a["bottom"] < b["top"],
                        a["top"] > b["bottom"]])

        i = 0
        while i + 1 < len(layouts):
            j = i + 1
            while j < min(i + far, len(layouts)) \
                    and (layouts[i].get("type", "") != layouts[j].get("type", "")
                         or notOverlapped(layouts[i], layouts[j])):
                j += 1
            if j >= min(i + far, len(layouts)):
                i += 1
                continue
            if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
                i += 1
                continue

            if layouts[i].get("score") and layouts[j].get("score"):
                if layouts[i]["score"] > layouts[j]["score"]:
                    layouts.pop(j)
                else:
                    layouts.pop(i)
                continue

            area_i, area_i_1 = 0, 0
            for b in boxes:
                if not notOverlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not notOverlapped(b, layouts[j]):
                    area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)

            if area_i > area_i_1:
                layouts.pop(j)
            else:
                layouts.pop(i)

        return layouts
    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))

        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs
    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        s, e, ii = 0, len(bxs), 0
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break
        while s < ii:
            if box["top"] > bxs[s]["bottom"]:
                s += 1
            break
        while e - 1 > ii:
            if box["bottom"] < bxs[e - 1]["top"]:
                e -= 1
            break

        max_overlaped_i, max_overlaped = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov <= max_overlaped:
                continue
            max_overlaped_i = i
            max_overlaped = ov

        return max_overlaped_i

    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        if not boxes:
            return
        max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
        s, e = 0, len(boxes)
        for i in range(s, e):
            ov = Recognizer.overlapped_area(box, boxes[i])
            _ov = Recognizer.overlapped_area(boxes[i], box)
            if (ov, _ov) < (max_overlaped, _max_overlaped):
                continue
            max_overlaped_i = i
            max_overlaped = ov
            _max_overlaped = _ov

        return max_overlaped_i
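Note the tuple comparison above: (ov, _ov) < (max_overlaped, _max_overlaped) is lexicographic, so the forward overlap ratio dominates and the reverse ratio only breaks ties:

    (0.5, 0.1) < (0.5, 0.3)  # True  -> candidate skipped: equal forward overlap, smaller reverse
    (0.6, 0.0) < (0.5, 0.9)  # False -> candidate kept: larger forward overlap wins outright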
    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                for b in self.ort_sess.run(None, ins)[0]:
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    if clsid >= len(self.label_list):
                        cron_logger.warning(f"bad category id")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        # seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
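End to end, __call__ batches the images, runs the ONNX session per image, and returns one list of {type, bbox, score} dicts per input image. A minimal sketch (import path, label list, and a local weights directory are assumptions made to avoid the model download):

    import numpy as np
    from deepdoc.vision import Recognizer  # assumed package path

    # Toy two-label list for illustration; real callers pass the model's full label set.
    rec = Recognizer(["_background_", "table"], "tsr", model_dir="/path/to/deepdoc/weights")
    dets = rec([np.zeros((800, 608, 3), dtype=np.uint8)], thr=0.5)
    # dets[0] is a list like [{"type": "table", "bbox": [x0, y0, x1, y1], "score": 0.93}]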
@@ -0,0 +1,556 @@
import logging
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

from api.utils.file_utils import get_project_base_directory
from rag.nlp import huqie
from .recognizer import Recognizer


class TableStructureRecognizer(Recognizer):
    def __init__(self):
        self.labels = [
            "table",
            "table column",
            "table row",
            "table column header",
            "table projected row header",
            "table spanning cell",
        ]
        super().__init__(self.labels, "tsr",
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, images, thr=0.5):
        tbls = super().__call__(images, thr)
        res = []
        # align left & right for rows, align top & bottom for columns
        for tbl in tbls:
            lts = [{"label": b["type"],
                    "score": b["score"],
                    "x0": b["bbox"][0], "x1": b["bbox"][2],
                    "top": b["bbox"][1], "bottom": b["bbox"][-1]
                    } for b in tbl]
            if not lts:
                continue

            left = [b["x0"] for b in lts if b["label"].find(
                "row") > 0 or b["label"].find("header") > 0]
            right = [b["x1"] for b in lts if b["label"].find(
                "row") > 0 or b["label"].find("header") > 0]
            if not left:
                continue
            left = np.median(left) if len(left) > 4 else np.min(left)
            right = np.median(right) if len(right) > 4 else np.max(right)
            for b in lts:
                if b["label"].find("row") > 0 or b["label"].find("header") > 0:
                    if b["x0"] > left:
                        b["x0"] = left
                    if b["x1"] < right:
                        b["x1"] = right

            top = [b["top"] for b in lts if b["label"] == "table column"]
            bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
            if not top:
                res.append(lts)
                continue
            top = np.median(top) if len(top) > 4 else np.min(top)
            bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
            for b in lts:
                if b["label"] == "table column":
                    if b["top"] > top:
                        b["top"] = top
                    if b["bottom"] < bottom:
                        b["bottom"] = bottom

            res.append(lts)
        return res
    @staticmethod
    def is_caption(bx):
        patt = [
            r"[图表]+[ 0-9::]{2,}"
        ]
        if any([re.match(p, bx["text"].strip()) for p in patt]) \
                or bx["layout_type"].find("caption") >= 0:
            return True
        return False

    def __blockType(self, b):
        patt = [
            ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
            (r"^(20|19)[0-9]{2}年$", "Dt"),
            (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
            ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
            (r"^第*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
            ("^[0-9.,+%/ -]+$", "Nu"),
            (r"^[0-9A-Z/\._~-]+$", "Ca"),
            (r"^[A-Z]*[a-z' -]+$", "En"),
            (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
            (r"^.{1}$", "Sg")
        ]
        for p, n in patt:
            if re.search(p, b["text"].strip()):
                return n
        tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
        if len(tks) > 3:
            if len(tks) < 12:
                return "Tx"
            else:
                return "Lx"
        if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
            return "Nr"
        return "Ot"
    def construct_table(self, boxes, is_english=False, html=False):
        cap = ""
        i = 0
        while i < len(boxes):
            if self.is_caption(boxes[i]):
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
            i += 1

        if not boxes:
            return []
        for b in boxes:
            b["btype"] = self.__blockType(b)
        max_type = Counter([b["btype"] for b in boxes]).items()
        max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
        logging.debug("MAXTYPE: " + max_type)

        rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
        rowh = np.min(rowh) if rowh else 0
        boxes = self.sort_R_firstly(boxes, rowh / 2)
        boxes[0]["rn"] = 0
        rows = [[boxes[0]]]
        btm = boxes[0]["bottom"]
        for b in boxes[1:]:
            b["rn"] = len(rows) - 1
            lst_r = rows[-1]
            if lst_r[-1].get("R", "") != b.get("R", "") \
                    or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                        ):  # new row
                btm = b["bottom"]
                b["rn"] += 1
                rows.append([b])
                continue
            btm = (btm + b["bottom"]) / 2.
            rows[-1].append(b)

        colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
        colwm = np.min(colwm) if colwm else 0
        crosspage = len(set([b["page_number"] for b in boxes])) > 1
        if crosspage:
            boxes = self.sort_X_firstly(boxes, colwm / 2, False)
        else:
            boxes = self.sort_C_firstly(boxes, colwm / 2)
        boxes[0]["cn"] = 0
        cols = [[boxes[0]]]
        right = boxes[0]["x1"]
        for b in boxes[1:]:
            b["cn"] = len(cols) - 1
            lst_c = cols[-1]
            if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
                    "page_number"]) \
                    or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
                right = b["x1"]
                b["cn"] += 1
                cols.append([b])
                continue
            right = (right + b["x1"]) / 2.
            cols[-1].append(b)

        tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
        for b in boxes:
            tbl[b["rn"]][b["cn"]].append(b)

        if len(rows) >= 4:
            # merge columns that have only one non-empty cell into a neighbor
            j = 0
            while j < len(tbl[0]):
                e, ii = 0, 0
                for i in range(len(tbl)):
                    if tbl[i][j]:
                        e += 1
                        ii = i
                    if e > 1:
                        break
                if e > 1:
                    j += 1
                    continue
                f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                     [j - 1][0].get("text")) or j == 0
                ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                      [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
                if f and ff:
                    j += 1
                    continue
                bx = tbl[ii][j][0]
                logging.debug("Relocate column single: " + bx["text"])
                # column j has only one value
                left, right = 100000, 100000
                if j > 0 and not f:
                    for i in range(len(tbl)):
                        if tbl[i][j - 1]:
                            left = min(left, np.min(
                                [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
                if j + 1 < len(tbl[0]) and not ff:
                    for i in range(len(tbl)):
                        if tbl[i][j + 1]:
                            right = min(right, np.min(
                                [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
                assert left < 100000 or right < 100000
                if left < right:
                    for jj in range(j, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j - 1]:
                        tbl[ii][j - 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j - 1] = tbl[ii][j]
                    for i in range(len(tbl)):
                        tbl[i].pop(j)
                else:
                    for jj in range(j + 1, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j + 1]:
                        tbl[ii][j + 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j + 1] = tbl[ii][j]
                    for i in range(len(tbl)):
                        tbl[i].pop(j)
                cols.pop(j)
        assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
            len(cols), len(tbl[0]))

        if len(cols) >= 4:
            # merge rows that have only one non-empty cell into a neighbor
            i = 0
            while i < len(tbl):
                e, jj = 0, 0
                for j in range(len(tbl[i])):
                    if tbl[i][j]:
                        e += 1
                        jj = j
                    if e > 1:
                        break
                if e > 1:
                    i += 1
                    continue
                f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                     [jj][0].get("text")) or i == 0
                ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                      [jj][0].get("text")) or i + 1 >= len(tbl)
                if f and ff:
                    i += 1
                    continue
                bx = tbl[i][jj][0]
                logging.debug("Relocate row single: " + bx["text"])
                # row i has only one value
                up, down = 100000, 100000
                if i > 0 and not f:
                    for j in range(len(tbl[i - 1])):
                        if tbl[i - 1][j]:
                            up = min(up, np.min(
                                [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
                if i + 1 < len(tbl) and not ff:
                    for j in range(len(tbl[i + 1])):
                        if tbl[i + 1][j]:
                            down = min(down, np.min(
                                [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
                assert up < 100000 or down < 100000
                if up < down:
                    for ii in range(i, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i - 1][jj]:
                        tbl[i - 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i - 1][jj] = tbl[i][jj]
                    tbl.pop(i)
                else:
                    for ii in range(i + 1, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i + 1][jj]:
                        tbl[i + 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i + 1][jj] = tbl[i][jj]
                    tbl.pop(i)
                rows.pop(i)

        # decide which rows are headers
        hdset = set([])
        for i in range(len(tbl)):
            cnt, h = 0, 0
            for j, arr in enumerate(tbl[i]):
                if not arr:
                    continue
                cnt += 1
                if max_type == "Nu" and arr[0]["btype"] == "Nu":
                    continue
                if any([a.get("H") for a in arr]) \
                        or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                    h += 1
            if h / cnt > 0.5:
                hdset.add(i)

        if html:
            return [self.__html_table(cap, hdset,
                                      self.__cal_spans(boxes, rows,
                                                       cols, tbl, True)
                                      )]

        return self.__desc_table(cap, hdset,
                                 self.__cal_spans(boxes, rows, cols, tbl, False),
                                 is_english)
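A hedged usage sketch of construct_table. The box keys follow what the method reads — text, page_number, the box geometry, layout_type (checked by is_caption), optional R/C row and column indices with their R_top/R_bott and C_left/C_right extents from the TSR stage, and H marking header cells; all values below are invented:

    cell = dict(page_number=1, layout_type="table")
    boxes = [
        {**cell, "text": "指标", "x0": 10, "x1": 60, "top": 10, "bottom": 20,
         "R": 0, "R_top": 10, "R_bott": 20, "C": 0, "C_left": 10, "C_right": 60, "H": 1},
        {**cell, "text": "2023年", "x0": 70, "x1": 120, "top": 10, "bottom": 20,
         "R": 0, "R_top": 10, "R_bott": 20, "C": 1, "C_left": 70, "C_right": 120, "H": 1},
        {**cell, "text": "营收", "x0": 10, "x1": 60, "top": 30, "bottom": 40,
         "R": 1, "R_top": 30, "R_bott": 40, "C": 0, "C_left": 10, "C_right": 60},
        {**cell, "text": "1,200", "x0": 70, "x1": 120, "top": 30, "bottom": 40,
         "R": 1, "R_top": 30, "R_bott": 40, "C": 1, "C_left": 70, "C_right": 120},
    ]
    tsr = TableStructureRecognizer()  # loads the tsr model
    print(tsr.construct_table(boxes, is_english=False, html=False))
    # expected shape: one descriptive line per data row, e.g. ['指标:营收; 2023年:1,200']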
    def __html_table(self, cap, hdset, tbl):
        # construct HTML
        html = "<table>"
        if cap:
            html += f"<caption>{cap}</caption>"
        for i in range(len(tbl)):
            row = "<tr>"
            txts = []
            for j, arr in enumerate(tbl[i]):
                if arr is None:
                    continue
                if not arr:
                    row += "<td></td>" if i not in hdset else "<th></th>"
                    continue
                txt = ""
                if arr:
                    h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
                    txt = "".join([c["text"]
                                   for c in self.sort_Y_firstly(arr, h)])
                txts.append(txt)
                sp = ""
                if arr[0].get("colspan"):
                    sp = "colspan={}".format(arr[0]["colspan"])
                if arr[0].get("rowspan"):
                    sp += " rowspan={}".format(arr[0]["rowspan"])
                if i in hdset:
                    row += f"<th {sp} >" + txt + "</th>"
                else:
                    row += f"<td {sp} >" + txt + "</td>"

            if i in hdset:
                if all([t in hdset for t in txts]):
                    continue
                for t in txts:
                    hdset.add(t)

            if row != "<tr>":
                row += "</tr>"
            else:
                row = ""
            html += "\n" + row
        html += "\n</table>"
        return html
    def __desc_table(self, cap, hdr_rowno, tbl, is_english):
        # use the text of every column in the header rows as the header text
        clmno = len(tbl[0])
        rowno = len(tbl)
        headers = {}
        hdrset = set()
        lst_hdr = []
        de = "的" if not is_english else " for "
        for r in sorted(list(hdr_rowno)):
            headers[r] = ["" for _ in range(clmno)]
            for i in range(clmno):
                if not tbl[r][i]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[r][i]])
                headers[r][i] = txt
                hdrset.add(txt)
            if all([not t for t in headers[r]]):
                del headers[r]
                hdr_rowno.remove(r)
                continue
            for j in range(clmno):
                if headers[r][j]:
                    continue
                if j >= len(lst_hdr):
                    break
                headers[r][j] = lst_hdr[j]
            lst_hdr = headers[r]
        for i in range(rowno):
            if i not in hdr_rowno:
                continue
            for j in range(i + 1, rowno):
                if j not in hdr_rowno:
                    break
                for k in range(clmno):
                    if not headers[j - 1][k]:
                        continue
                    if headers[j][k].find(headers[j - 1][k]) >= 0:
                        continue
                    if len(headers[j][k]) > len(headers[j - 1][k]):
                        headers[j][k] += (de if headers[j][k]
                                          else "") + headers[j - 1][k]
                    else:
                        headers[j][k] = headers[j - 1][k] \
                            + (de if headers[j - 1][k] else "") \
                            + headers[j][k]

        logging.debug(
            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
        row_txt = []
        for i in range(rowno):
            if i in hdr_rowno:
                continue
            rtxt = []

            def append(delimer):
                nonlocal rtxt, row_txt
                rtxt = delimer.join(rtxt)
                if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
                    row_txt[-1] += "\n" + rtxt
                else:
                    row_txt.append(rtxt)

            r = 0
            if len(headers.items()):
                _arr = [(i - r, r) for r, _ in headers.items() if r < i]
                if _arr:
                    _, r = min(_arr, key=lambda x: x[0])

            if r not in headers and clmno <= 2:
                for j in range(clmno):
                    if not tbl[i][j]:
                        continue
                    txt = "".join([a["text"].strip() for a in tbl[i][j]])
                    if txt:
                        rtxt.append(txt)
                if rtxt:
                    append(":")
                continue

            for j in range(clmno):
                if not tbl[i][j]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[i][j]])
                if not txt:
                    continue
                ctt = headers[r][j] if r in headers else ""
                if ctt:
                    ctt += ":"
                ctt += txt
                if ctt:
                    rtxt.append(ctt)
            if rtxt:
                row_txt.append("; ".join(rtxt))
        if cap:
            if is_english:
                from_ = " in "
            else:
                from_ = "来自"
            row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
        return row_txt
    def __cal_spans(self, boxes, rows, cols, tbl, html=True):
        # calculate spans
        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
                for cln in cols]
        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
                for cln in cols]
        rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
                for row in rows]
        rbtm = [np.mean([c.get("R_btm", c["bottom"])
                         for c in row]) for row in rows]
        for b in boxes:
            if "SP" not in b:
                continue
            b["colspan"] = [b["cn"]]
            b["rowspan"] = [b["rn"]]
            # col span
            for j in range(0, len(clft)):
                if j == b["cn"]:
                    continue
                if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                    continue
                if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                    continue
                b["colspan"].append(j)
            # row span
            for j in range(0, len(rtop)):
                if j == b["rn"]:
                    continue
                if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                    continue
                if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                    continue
                b["rowspan"].append(j)

        def join(arr):
            if not arr:
                return ""
            return "".join([t["text"] for t in arr])

        # remove the spanning cells
        for i in range(len(tbl)):
            for j, arr in enumerate(tbl[i]):
                if not arr:
                    continue
                if all(["rowspan" not in a and "colspan" not in a for a in arr]):
                    continue
                rowspan, colspan = [], []
                for a in arr:
                    if isinstance(a.get("rowspan", 0), list):
                        rowspan.extend(a["rowspan"])
                    if isinstance(a.get("colspan", 0), list):
                        colspan.extend(a["colspan"])
                rowspan, colspan = set(rowspan), set(colspan)
                if len(rowspan) < 2 and len(colspan) < 2:
                    for a in arr:
                        if "rowspan" in a:
                            del a["rowspan"]
                        if "colspan" in a:
                            del a["colspan"]
                    continue
                rowspan, colspan = sorted(rowspan), sorted(colspan)
                rowspan = list(range(rowspan[0], rowspan[-1] + 1))
                colspan = list(range(colspan[0], colspan[-1] + 1))
                assert i in rowspan, rowspan
                assert j in colspan, colspan
                arr = []
                for r in rowspan:
                    for c in colspan:
                        arr_txt = join(arr)
                        if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                            arr.extend(tbl[r][c])
                        tbl[r][c] = None if html else arr
                for a in arr:
                    if len(rowspan) > 1:
                        a["rowspan"] = len(rowspan)
                    elif "rowspan" in a:
                        del a["rowspan"]
                    if len(colspan) > 1:
                        a["colspan"] = len(colspan)
                    elif "colspan" in a:
                        del a["colspan"]
                tbl[rowspan[0]][colspan[0]] = arr
        return tbl
@@ -1,2 +0,0 @@
from .ocr import OCR
from .recognizer import Recognizer

@@ -1,139 +0,0 @@
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os

import onnxruntime as ort
from huggingface_hub import snapshot_download

from .operators import *
from rag.settings import cron_logger


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        if ort.get_device() == "GPU":
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.label_list = label_list

    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))

        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                for b in self.ort_sess.run(None, ins)[0]:
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    if clsid >= len(self.label_list):
                        cron_logger.warning(f"bad category id")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        # seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
@@ -21,7 +21,7 @@ from datetime import datetime
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from deepdoc.parser import HuParser
from deepdoc.parser import PdfParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm

@@ -80,7 +80,7 @@ def dispatch():
            tsks = []
            if r["type"] == FileType.PDF.value:
                pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
                pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
                for s, e in r["parser_config"].get("pages", [(0, 100000)]):
                    e = min(e, pages)
                    for p in range(s, e, 10):
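The dispatch loop fans a PDF out into one task per 10-page window, clipped to the real page count. The task fields themselves are cut off by the hunk, so the upper bound in the sketch below is an assumption consistent with the loop step; for a 23-page file with the default pages config:

    pages = 23
    ranges = []
    for s, e in [(0, 100000)]:  # default parser_config["pages"]
        e = min(e, pages)
        for p in range(s, e, 10):
            ranges.append((p, min(e, p + 10)))  # per-task page window (upper bound assumed)
    print(ranges)  # [(0, 10), (10, 20), (20, 23)]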