* rename vision, add layout and tsr recognizer
* trivial fixing
 @manager.route('/set', methods=['POST'])
 @login_required
-@validate_request("dialog_id")
 def set_conversation():
     req = request.json
     conv_id = req.get("conversation_id")

 @manager.route('/completion', methods=['POST'])
 @login_required
-@validate_request("dialog_id", "messages")
+@validate_request("conversation_id", "messages")
 def completion():
     req = request.json
     msg = []
     for m in req["messages"]:
         if m["role"] == "assistant" and not msg: continue
         msg.append({"role": m["role"], "content": m["content"]})
     try:
-        e, dia = DialogService.get_by_id(req["dialog_id"])
+        e, conv = ConversationService.get_by_id(req["conversation_id"])
+        if not e:
+            return get_data_error_result(retmsg="Conversation not found!")
+        conv.message.append(msg[-1])
+        e, dia = DialogService.get_by_id(conv.dialog_id)
         if not e:
             return get_data_error_result(retmsg="Dialog not found!")
-        del req["dialog_id"]
+        del req["conversation_id"]
         del req["messages"]
-        return get_json_result(data=chat(dia, msg, **req))
+        ans = chat(dia, msg, **req)
+        conv.reference.append(ans["reference"])
+        conv.message.append({"role": "assistant", "content": ans["answer"]})
+        ConversationService.update_by_id(conv.id, conv.to_dict())
+        return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)
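With this change, `/completion` is keyed by `conversation_id` rather than `dialog_id`, and the server now persists the latest user message, the generated answer, and its retrieval `reference` on the conversation record. A minimal client sketch; the host, route prefix, and auth header are placeholders, not part of this diff:

import requests

resp = requests.post(
    "http://localhost:9380/v1/conversation/completion",   # hypothetical base URL
    headers={"Authorization": "Bearer <token>"},          # hypothetical auth
    json={
        "conversation_id": "<conv_id>",   # was "dialog_id" before this change
        "messages": [{"role": "user", "content": "Summarize chapter 2."}],
    },
)
data = resp.json()["data"]
print(data["answer"])      # generated answer
print(data["reference"])   # retrieved chunks backing it (renamed from "retrieval")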
                          dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-    if not knowledges and prompt_config["empty_response"]:
-        return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
+    if not knowledges and prompt_config.get("empty_response"):
+        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
     gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
     answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)

-    answer = retrievaler.insert_citations(answer,
+    if knowledges:
+        answer = retrievaler.insert_citations(answer,
                                           [ck["content_ltks"] for ck in kbinfos["chunks"]],
                                           [ck["vector"] for ck in kbinfos["chunks"]],
                                           embd_mdl,
                                           vtweight=dialog.vector_similarity_weight)
     for c in kbinfos["chunks"]:
         if c.get("vector"): del c["vector"]
-    return {"answer": answer, "retrieval": kbinfos}
+    return {"answer": answer, "reference": kbinfos}

 def use_sql(question, field_map, tenant_id, chat_mdl):
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
-        mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
+        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
         llms = LLMService.get_all()
         llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
         for m in llms:
-            m["available"] = m["llm_name"] in mdlnms
+            m["available"] = m["fid"] in facts

         res = {}
         for m in llms:
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=512, null=True, help_text="process message", default="")
+    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
     process_begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)

     run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
+    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")

 class Dialog(DataBaseModel):

     dialog_id = CharField(max_length=32, null=False, index=True)
     name = CharField(max_length=255, null=True, help_text="conversation name")
     message = JSONField(null=True)
+    reference = JSONField(null=True, default=[])

     class Meta:
         db_table = "conversation"
         model_config = cls.get_api_key(tenant_id, mdlnm)
         if not model_config:
-            raise LookupError("Model({}) not found".format(mdlnm))
+            raise LookupError("Model({}) not authorized".format(mdlnm))
         model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
# __init__.py (new): the vision package now also exports the layout and
# table-structure recognizers.
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
# layout_recognizer.py (new)
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

from api.utils.file_utils import get_project_base_directory
from .recognizer import Recognizer


class LayoutRecognizer(Recognizer):
    def __init__(self, domain):
        self.layout_labels = [
            "_background_",
            "Text",
            "Title",
            "Figure",
            "Figure caption",
            "Table",
            "Table caption",
            "Header",
            "Footer",
            "Reference",
            "Equation",
        ]
        super().__init__(self.layout_labels, domain,
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
        def __is_garbage(b):
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        garbages = {}
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            lts = self.sort_Y_firstly(lts, np.mean([l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Tag layout type, layouts are ready
            def findLayout(ty):
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belongs to nothing
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    if lts_[ii]["type"] in ["footer", "header", "reference"]:
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue

                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"]
                    i += 1

            for lt in ["footer", "header", "reference", "figure caption",
                       "table caption", "title", "text", "table", "figure", "equation"]:
                findLayout(lt)

            # add a box to figure layouts which have no text box
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] == "figure"]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
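For orientation, a hedged usage sketch of the new recognizer; the "layout" domain string and the OCR box fields are assumptions read off the constructor and `__call__` above, not documented API:

import numpy as np
from PIL import Image

detector = LayoutRecognizer("layout")      # assumed domain; loads <domain>.onnx

pages = [Image.open("page1.png")]          # one image per page
ocr_res = [[{"text": "Q3 revenue grew 12%",
             "x0": 10, "x1": 200, "top": 30, "bottom": 50}]]  # one box list per page

boxes, page_layouts = detector(pages, ocr_res, scale_factor=1)
# surviving boxes now carry "layout_type"/"layoutno"; repeated headers,
# footers and other garbage lines have been filtered out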
         self.rec_batch_num = 16
         postprocess_params = {
             'name': 'CTCLabelDecode',
-            "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
+            "character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
             "use_space_char": True
         }
         self.postprocess_op = build_post_process(postprocess_params)

         """
         if not model_dir:
-            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
+            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
         self.text_detector = TextDetector(model_dir)
         self.text_recognizer = TextRecognizer(model_dir)
# recognizer.py (new)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from copy import deepcopy

import onnxruntime as ort
from huggingface_hub import snapshot_download

from . import seeit
from .operators import *
from rag.settings import cron_logger


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        if ort.get_device() == "GPU":
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.label_list = label_list

    @staticmethod
    def sort_Y_firstly(arr, threashold):
        # sort by y (top) first, then by x0
        arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
                        and arr[j + 1]["x0"] < arr[j]["x0"]:
                    tmp = deepcopy(arr[j])
                    arr[j] = deepcopy(arr[j + 1])
                    arr[j + 1] = deepcopy(tmp)
        return arr

    @staticmethod
    def sort_X_firstly(arr, threashold, copy=True):
        # sort by x0 first, then by y (top)
        arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
                        and arr[j + 1]["top"] < arr[j]["top"]:
                    tmp = deepcopy(arr[j]) if copy else arr[j]
                    arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
                    arr[j + 1] = deepcopy(tmp) if copy else tmp
        return arr

    @staticmethod
    def sort_C_firstly(arr, thr=0):
        # sort by column number ("C") first, then by y (top)
        # sorted(arr, key=lambda r: (r["x0"], r["top"]))
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # restore the order using the threshold
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] \
                        or (
                        arr[j + 1]["C"] == arr[j]["C"]
                        and arr[j + 1]["top"] < arr[j]["top"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr
        return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))  # unreachable

    @staticmethod
    def sort_R_firstly(arr, thr=0):
        # sort by row number ("R") first, then by x0
        # sorted(arr, key=lambda r: (r["top"], r["x0"]))
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] \
                        or (
                        arr[j + 1]["R"] == arr[j]["R"]
                        and arr[j + 1]["x0"] < arr[j]["x0"]
                ):
                    tmp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = tmp
        return arr

    @staticmethod
    def overlapped_area(a, b, ratio=True):
        # area of the intersection of a and b; if ratio, normalized by a's own area
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Bad box! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Bad box! T:{},B:{},X0:{},X1:{} => {}".format(
            tp, btm, x0, x1, b)
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
            x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov
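    # A hand-worked check of the semantics above (illustrative, not from the
    # diff): with a = {"top": 0, "bottom": 10, "x0": 0, "x1": 10} (area 100)
    # and b = {"top": 5, "bottom": 15, "x0": 5, "x1": 15}, the intersection is
    # 5 x 5 = 25, so:
    #   Recognizer.overlapped_area(a, b)         -> 0.25  (25 / a's area)
    #   Recognizer.overlapped_area(a, b, False)  -> 25.0  (raw intersection)
    # Note the ratio normalizes by the first box's area, not the union, so the
    # function is asymmetric in general.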
    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        def notOverlapped(a, b):
            return any([a["x1"] < b["x0"],
                        a["x0"] > b["x1"],
                        a["bottom"] < b["top"],
                        a["top"] > b["bottom"]])

        i = 0
        while i + 1 < len(layouts):
            j = i + 1
            while j < min(i + far, len(layouts)) \
                    and (layouts[i].get("type", "") != layouts[j].get("type", "")
                         or notOverlapped(layouts[i], layouts[j])):
                j += 1
            if j >= min(i + far, len(layouts)):
                i += 1
                continue
            if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
                i += 1
                continue

            if layouts[i].get("score") and layouts[j].get("score"):
                if layouts[i]["score"] > layouts[j]["score"]:
                    layouts.pop(j)
                else:
                    layouts.pop(i)
                continue

            area_i, area_i_1 = 0, 0
            for b in boxes:
                if not notOverlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not notOverlapped(b, layouts[j]):
                    area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)

            if area_i > area_i_1:
                layouts.pop(j)
            else:
                layouts.pop(i)

        return layouts

    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs

    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        s, e, ii = 0, len(bxs), 0
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break
        while s < ii:
            if box["top"] > bxs[s]["bottom"]:
                s += 1
            break
        while e - 1 > ii:
            if box["bottom"] < bxs[e - 1]["top"]:
                e -= 1
            break

        max_overlaped_i, max_overlaped = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov <= max_overlaped:
                continue
            max_overlaped_i = i
            max_overlaped = ov

        return max_overlaped_i

    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        if not boxes:
            return
        max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
        s, e = 0, len(boxes)
        for i in range(s, e):
            ov = Recognizer.overlapped_area(box, boxes[i])
            _ov = Recognizer.overlapped_area(boxes[i], box)
            if (ov, _ov) < (max_overlaped, _max_overlaped):
                continue
            max_overlaped_i = i
            max_overlaped = ov
            _max_overlaped = _ov

        return max_overlaped_i
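    # Gloss (inferred from the code above, not stated in the diff): the tuple
    # comparison `(ov, _ov) < (max_overlaped, _max_overlaped)` is
    # lexicographic, so candidates are ranked primarily by how much of `box`
    # falls inside the candidate (ov), with the reverse overlap (_ov) as a
    # tie-breaker. Seeding max_overlaped with `thr` means a box is assigned to
    # a layout region only when at least `thr` of it lies inside that region.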
    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else: imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                for b in self.ort_sess.run(None, ins)[0]:
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    if clsid >= len(self.label_list):
                        cron_logger.warning(f"bad category id")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        #seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
# table_structure_recognizer.py (new)
import logging
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

from api.utils.file_utils import get_project_base_directory
from rag.nlp import huqie
from .recognizer import Recognizer


class TableStructureRecognizer(Recognizer):
    def __init__(self):
        self.labels = [
            "table",
            "table column",
            "table row",
            "table column header",
            "table projected row header",
            "table spanning cell",
        ]
        super().__init__(self.labels, "tsr",
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, images, thr=0.5):
        tbls = super().__call__(images, thr)
        res = []
        # align left & right for rows, align top & bottom for columns
        for tbl in tbls:
            lts = [{"label": b["type"],
                    "score": b["score"],
                    "x0": b["bbox"][0], "x1": b["bbox"][2],
                    "top": b["bbox"][1], "bottom": b["bbox"][-1]
                    } for b in tbl]
            if not lts:
                continue

            left = [b["x0"] for b in lts if b["label"].find(
                "row") > 0 or b["label"].find("header") > 0]
            right = [b["x1"] for b in lts if b["label"].find(
                "row") > 0 or b["label"].find("header") > 0]
            if not left:
                continue
            left = np.median(left) if len(left) > 4 else np.min(left)
            right = np.median(right) if len(right) > 4 else np.max(right)
            for b in lts:
                if b["label"].find("row") > 0 or b["label"].find("header") > 0:
                    if b["x0"] > left:
                        b["x0"] = left
                    if b["x1"] < right:
                        b["x1"] = right

            top = [b["top"] for b in lts if b["label"] == "table column"]
            bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
            if not top:
                res.append(lts)
                continue
            top = np.median(top) if len(top) > 4 else np.min(top)
            bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
            for b in lts:
                if b["label"] == "table column":
                    if b["top"] > top:
                        b["top"] = top
                    if b["bottom"] < bottom:
                        b["bottom"] = bottom

            res.append(lts)
        return res

    @staticmethod
    def is_caption(bx):
        patt = [
            r"[图表]+[ 0-9::]{2,}"
        ]
        if any([re.match(p, bx["text"].strip()) for p in patt]) \
                or bx["layout_type"].find("caption") >= 0:
            return True
        return False

    def __blockType(self, b):
        patt = [
            ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
            (r"^(20|19)[0-9]{2}年$", "Dt"),
            (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
            ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
            (r"^第*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
            ("^[0-9.,+%/ -]+$", "Nu"),
            (r"^[0-9A-Z/\._~-]+$", "Ca"),
            (r"^[A-Z]*[a-z' -]+$", "En"),
            (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
            (r"^.{1}$", "Sg")
        ]
        for p, n in patt:
            if re.search(p, b["text"].strip()):
                return n
        tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
        if len(tks) > 3:
            if len(tks) < 12:
                return "Tx"
            else:
                return "Lx"
        if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
            return "Nr"
        return "Ot"

    def construct_table(self, boxes, is_english=False, html=False):
        cap = ""
        i = 0
        while i < len(boxes):
            if self.is_caption(boxes[i]):
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
            i += 1

        if not boxes:
            return []
        for b in boxes:
            b["btype"] = self.__blockType(b)
        max_type = Counter([b["btype"] for b in boxes]).items()
        max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
        logging.debug("MAXTYPE: " + max_type)

        rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
        rowh = np.min(rowh) if rowh else 0
        boxes = self.sort_R_firstly(boxes, rowh / 2)
        boxes[0]["rn"] = 0
        rows = [[boxes[0]]]
        btm = boxes[0]["bottom"]
        for b in boxes[1:]:
            b["rn"] = len(rows) - 1
            lst_r = rows[-1]
            if lst_r[-1].get("R", "") != b.get("R", "") \
                    or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                        ):  # new row
                btm = b["bottom"]
                b["rn"] += 1
                rows.append([b])
                continue
            btm = (btm + b["bottom"]) / 2.
            rows[-1].append(b)

        colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
        colwm = np.min(colwm) if colwm else 0
        crosspage = len(set([b["page_number"] for b in boxes])) > 1
        if crosspage:
            boxes = self.sort_X_firstly(boxes, colwm / 2, False)
        else:
            boxes = self.sort_C_firstly(boxes, colwm / 2)
        boxes[0]["cn"] = 0
        cols = [[boxes[0]]]
        right = boxes[0]["x1"]
        for b in boxes[1:]:
            b["cn"] = len(cols) - 1
            lst_c = cols[-1]
            if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
                    "page_number"]) \
                    or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
                right = b["x1"]
                b["cn"] += 1
                cols.append([b])
                continue
            right = (right + b["x1"]) / 2.
            cols[-1].append(b)

        tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
        for b in boxes:
            tbl[b["rn"]][b["cn"]].append(b)

        if len(rows) >= 4:
            # remove columns with only a single populated cell
            j = 0
            while j < len(tbl[0]):
                e, ii = 0, 0
                for i in range(len(tbl)):
                    if tbl[i][j]:
                        e += 1
                        ii = i
                    if e > 1:
                        break
                if e > 1:
                    j += 1
                    continue
                f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                     [j - 1][0].get("text")) or j == 0
                ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                      [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
                if f and ff:
                    j += 1
                    continue
                bx = tbl[ii][j][0]
                logging.debug("Relocate column single: " + bx["text"])
                # the j-th column has only one populated cell
                left, right = 100000, 100000
                if j > 0 and not f:
                    for i in range(len(tbl)):
                        if tbl[i][j - 1]:
                            left = min(left, np.min(
                                [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
                if j + 1 < len(tbl[0]) and not ff:
                    for i in range(len(tbl)):
                        if tbl[i][j + 1]:
                            right = min(right, np.min(
                                [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
                assert left < 100000 or right < 100000
                if left < right:
                    for jj in range(j, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j - 1]:
                        tbl[ii][j - 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j - 1] = tbl[ii][j]
                    for i in range(len(tbl)):
                        tbl[i].pop(j)
                else:
                    for jj in range(j + 1, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j + 1]:
                        tbl[ii][j + 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j + 1] = tbl[ii][j]
                    for i in range(len(tbl)):
                        tbl[i].pop(j)
                cols.pop(j)
        assert len(cols) == len(tbl[0]), "Column count mismatched: %d vs %d" % (
            len(cols), len(tbl[0]))

        if len(cols) >= 4:
            # remove rows with only a single populated cell
            i = 0
            while i < len(tbl):
                e, jj = 0, 0
                for j in range(len(tbl[i])):
                    if tbl[i][j]:
                        e += 1
                        jj = j
                    if e > 1:
                        break
                if e > 1:
                    i += 1
                    continue
                f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                     [jj][0].get("text")) or i == 0
                ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                      [jj][0].get("text")) or i + 1 >= len(tbl)
                if f and ff:
                    i += 1
                    continue
                bx = tbl[i][jj][0]
                logging.debug("Relocate row single: " + bx["text"])
                # the i-th row has only one populated cell
                up, down = 100000, 100000
                if i > 0 and not f:
                    for j in range(len(tbl[i - 1])):
                        if tbl[i - 1][j]:
                            up = min(up, np.min(
                                [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
                if i + 1 < len(tbl) and not ff:
                    for j in range(len(tbl[i + 1])):
                        if tbl[i + 1][j]:
                            down = min(down, np.min(
                                [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
                assert up < 100000 or down < 100000
                if up < down:
                    for ii in range(i, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i - 1][jj]:
                        tbl[i - 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i - 1][jj] = tbl[i][jj]
                    tbl.pop(i)
                else:
                    for ii in range(i + 1, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i + 1][jj]:
                        tbl[i + 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i + 1][jj] = tbl[i][jj]
                    tbl.pop(i)
                rows.pop(i)

        # which rows are headers
        hdset = set([])
        for i in range(len(tbl)):
            cnt, h = 0, 0
            for j, arr in enumerate(tbl[i]):
                if not arr:
                    continue
                cnt += 1
                if max_type == "Nu" and arr[0]["btype"] == "Nu":
                    continue
                if any([a.get("H") for a in arr]) \
                        or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                    h += 1
            if h / cnt > 0.5:
                hdset.add(i)

        if html:
            return [self.__html_table(cap, hdset,
                                      self.__cal_spans(boxes, rows,
                                                       cols, tbl, True)
                                      )]

        return self.__desc_table(cap, hdset,
                                 self.__cal_spans(boxes, rows, cols, tbl, False),
                                 is_english)
    def __html_table(self, cap, hdset, tbl):
        # construct HTML
        html = "<table>"
        if cap:
            html += f"<caption>{cap}</caption>"
        for i in range(len(tbl)):
            row = "<tr>"
            txts = []
            for j, arr in enumerate(tbl[i]):
                if arr is None:
                    continue
                if not arr:
                    row += "<td></td>" if i not in hdset else "<th></th>"
                    continue
                txt = ""
                if arr:
                    h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
                    txt = "".join([c["text"]
                                   for c in self.sort_Y_firstly(arr, h)])
                txts.append(txt)
                sp = ""
                if arr[0].get("colspan"):
                    sp = "colspan={}".format(arr[0]["colspan"])
                if arr[0].get("rowspan"):
                    sp += " rowspan={}".format(arr[0]["rowspan"])
                if i in hdset:
                    row += f"<th {sp} >" + txt + "</th>"
                else:
                    row += f"<td {sp} >" + txt + "</td>"

            if i in hdset:
                if all([t in hdset for t in txts]):
                    continue
                for t in txts:
                    hdset.add(t)

            if row != "<tr>":
                row += "</tr>"
            else:
                row = ""
            html += "\n" + row
        html += "\n</table>"
        return html

    def __desc_table(self, cap, hdr_rowno, tbl, is_english):
        # get the text of every column in the header rows to build header text
        clmno = len(tbl[0])
        rowno = len(tbl)
        headers = {}
        hdrset = set()
        lst_hdr = []
        de = "的" if not is_english else " for "
        for r in sorted(list(hdr_rowno)):
            headers[r] = ["" for _ in range(clmno)]
            for i in range(clmno):
                if not tbl[r][i]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[r][i]])
                headers[r][i] = txt
                hdrset.add(txt)
            if all([not t for t in headers[r]]):
                del headers[r]
                hdr_rowno.remove(r)
                continue
            for j in range(clmno):
                if headers[r][j]:
                    continue
                if j >= len(lst_hdr):
                    break
                headers[r][j] = lst_hdr[j]
            lst_hdr = headers[r]
        for i in range(rowno):
            if i not in hdr_rowno:
                continue
            for j in range(i + 1, rowno):
                if j not in hdr_rowno:
                    break
                for k in range(clmno):
                    if not headers[j - 1][k]:
                        continue
                    if headers[j][k].find(headers[j - 1][k]) >= 0:
                        continue
                    if len(headers[j][k]) > len(headers[j - 1][k]):
                        headers[j][k] += (de if headers[j][k]
                                          else "") + headers[j - 1][k]
                    else:
                        headers[j][k] = headers[j - 1][k] \
                            + (de if headers[j - 1][k] else "") \
                            + headers[j][k]

        logging.debug(
            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")

        row_txt = []
        for i in range(rowno):
            if i in hdr_rowno:
                continue
            rtxt = []

            def append(delimer):
                nonlocal rtxt, row_txt
                rtxt = delimer.join(rtxt)
                if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
                    row_txt[-1] += "\n" + rtxt
                else:
                    row_txt.append(rtxt)

            r = 0
            if len(headers.items()):
                _arr = [(i - r, r) for r, _ in headers.items() if r < i]
                if _arr:
                    _, r = min(_arr, key=lambda x: x[0])

            if r not in headers and clmno <= 2:
                for j in range(clmno):
                    if not tbl[i][j]:
                        continue
                    txt = "".join([a["text"].strip() for a in tbl[i][j]])
                    if txt:
                        rtxt.append(txt)
                if rtxt:
                    append(":")
                continue

            for j in range(clmno):
                if not tbl[i][j]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[i][j]])
                if not txt:
                    continue
                ctt = headers[r][j] if r in headers else ""
                if ctt:
                    ctt += ":"
                ctt += txt
                if ctt:
                    rtxt.append(ctt)

            if rtxt:
                row_txt.append("; ".join(rtxt))

        if cap:
            if is_english:
                from_ = " in "
            else:
                from_ = "来自"
            row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
        return row_txt

    def __cal_spans(self, boxes, rows, cols, tbl, html=True):
        # calculate spans
        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
                for cln in cols]
        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
                for cln in cols]
        rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
                for row in rows]
        rbtm = [np.mean([c.get("R_btm", c["bottom"])
                         for c in row]) for row in rows]
        for b in boxes:
            if "SP" not in b:
                continue
            b["colspan"] = [b["cn"]]
            b["rowspan"] = [b["rn"]]
            # col span
            for j in range(0, len(clft)):
                if j == b["cn"]:
                    continue
                if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                    continue
                if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                    continue
                b["colspan"].append(j)
            # row span
            for j in range(0, len(rtop)):
                if j == b["rn"]:
                    continue
                if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                    continue
                if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                    continue
                b["rowspan"].append(j)

        def join(arr):
            if not arr:
                return ""
            return "".join([t["text"] for t in arr])

        # remove the spanning cells
        for i in range(len(tbl)):
            for j, arr in enumerate(tbl[i]):
                if not arr:
                    continue
                if all(["rowspan" not in a and "colspan" not in a for a in arr]):
                    continue
                rowspan, colspan = [], []
                for a in arr:
                    if isinstance(a.get("rowspan", 0), list):
                        rowspan.extend(a["rowspan"])
                    if isinstance(a.get("colspan", 0), list):
                        colspan.extend(a["colspan"])
                rowspan, colspan = set(rowspan), set(colspan)
                if len(rowspan) < 2 and len(colspan) < 2:
                    for a in arr:
                        if "rowspan" in a:
                            del a["rowspan"]
                        if "colspan" in a:
                            del a["colspan"]
                    continue
                rowspan, colspan = sorted(rowspan), sorted(colspan)
                rowspan = list(range(rowspan[0], rowspan[-1] + 1))
                colspan = list(range(colspan[0], colspan[-1] + 1))
                assert i in rowspan, rowspan
                assert j in colspan, colspan
                arr = []
                for r in rowspan:
                    for c in colspan:
                        arr_txt = join(arr)
                        if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                            arr.extend(tbl[r][c])
                        tbl[r][c] = None if html else arr
                for a in arr:
                    if len(rowspan) > 1:
                        a["rowspan"] = len(rowspan)
                    elif "rowspan" in a:
                        del a["rowspan"]
                    if len(colspan) > 1:
                        a["colspan"] = len(colspan)
                    elif "colspan" in a:
                        del a["colspan"]

                tbl[rowspan[0]][colspan[0]] = arr
        return tbl
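A hedged sketch of how the table pieces compose end to end; the variable names and the pre-tagged row/column fields (`R`, `C`, `H`, `SP`, `R_top`, `C_left`, ...) are assumptions about the caller's wiring, not part of this diff:

tsr = TableStructureRecognizer()

# 1) detect table sub-structures (rows, columns, headers, spanning cells)
#    on cropped table images; thr=0.5 mirrors the default above
components = tsr([table_crop_image], thr=0.5)

# 2) after the caller tags OCR boxes with row/column numbers and header flags,
#    flatten each table into retrieval-friendly text rows ...
rows_as_text = tsr.construct_table(tagged_boxes, is_english=False, html=False)

# ... or into a single HTML <table> string (returned as a one-element list)
html_table = tsr.construct_table(tagged_boxes, is_english=False, html=True)[0]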
 from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
 from api.db.services.task_service import TaskService
-from deepdoc.parser import HuParser
+from deepdoc.parser import PdfParser
 from rag.settings import cron_logger
 from rag.utils import MINIO
 from rag.utils import findMaxTm

     tsks = []
     if r["type"] == FileType.PDF.value:
-        pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+        pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
         for s,e in r["parser_config"].get("pages", [(0,100000)]):
             e = min(e, pages)
             for p in range(s, e, 10):
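The loop above fans a PDF out into one task per 10-page window. A quick worked example (assuming each iteration creates a task covering pages [p, min(e, p + 10)), which is what the stride implies):

pages = 25                                    # as returned by total_page_number
s, e = 0, min(100000, pages)                  # default "pages" range
windows = [(p, min(e, p + 10)) for p in range(s, e, 10)]
# -> [(0, 10), (10, 20), (20, 25)]: three parsing tasks for a 25-page PDF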