@@ -58,7 +58,7 @@ def set_conversation():
    conv = {
        "id": get_uuid(),
        "dialog_id": req["dialog_id"],
-       "name": "New conversation",
+       "name": req.get("name", "New conversation"),
        "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
    }
    ConversationService.save(**conv)
@@ -102,7 +102,7 @@ def rm():
def list_convsersation():
    dialog_id = request.args["dialog_id"]
    try:
-       convs = ConversationService.query(dialog_id=dialog_id)
+       convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
@@ -185,5 +185,11 @@ def thumbnail(filename, blob):
        pass


def traversal_files(base):
    # Recursively yield the full path of every file under `base`.
    for root, ds, fs in os.walk(base):
        for f in fs:
            fullname = os.path.join(root, f)
            yield fullname
@@ -11,7 +11,36 @@ English | [简体中文](./README_zh.md)

With a bunch of documents from various domains in various formats, along with diverse retrieval requirements,
accurate analysis becomes a very challenging task. *Deep*Doc is born for that purpose.

-There 2 parts in *Deep*Doc so far: vision and parser.
+There are 2 parts in *Deep*Doc so far: vision and parser.

You can run the following test programs if you're interested in our results of OCR, layout recognition and TSR.
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]

options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './ocr_outputs'
```

```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]

options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './layouts_outputs'
  --threshold THRESHOLD
                        A threshold to filter out detections. Default: 0.5
  --mode {layout,tsr}   Task mode: layout recognition or table structure recognition
```
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!

```bash
export HF_ENDPOINT=https://hf-mirror.com
```
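If you prefer to warm the model cache from Python instead, here is a minimal sketch using `huggingface_hub` directly; the repo id `InfiniFlow/deepdoc` is taken from the `snapshot_download` call in `recognizer.py` further down this change.

```python
# Minimal sketch: pre-download the DeepDoc models once, so later runs can work offline.
from huggingface_hub import snapshot_download

model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
print(model_dir)  # local cache directory that holds the .onnx model files
```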
| <a name="2"></a> | |||
| ## 2. Vision | |||
@@ -19,9 +48,14 @@ There 2 parts in *Deep*Doc so far: vision and parser.

We use vision information to resolve problems as human beings do.
- OCR. Since many documents are presented as images, or can at least be transformed to images,
OCR is an essential, fundamental, and even universal solution for text extraction.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which has images demonstrating the positions of the results
and txt files containing the OCR text. A programmatic sketch follows the figures below.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://lh6.googleusercontent.com/2xdiSjaGWkZ71YdORc71Ujf7jCHmO6G-6ONklzGiUYEh3QZpjPo6MQ9eqEFX20am_cdW4Ck0YRraXEetXWnM08kJd99yhik13Cy0_YKUAq2zVGR15LzkovRAmK9iT4o3hcJ8dTpspaJKUwt6R4gN7So" width="300"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
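A minimal sketch of driving the OCR engine from Python rather than via t_ocr.py; the result format (a box plus a (text, score) pair per line) mirrors the loop inside t_ocr.py in this change, and 'page.png' is a hypothetical input:

```python
# Sketch: direct OCR call; mirrors what t_ocr.py does per image.
import numpy as np
from PIL import Image
from deepdoc.vision import OCR

ocr = OCR()                    # loads the detection + recognition models
img = Image.open("page.png")   # hypothetical input image
for box, (text, score) in ocr(np.array(img)):
    print(round(score, 3), text)   # one recognized text line per detected box
```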
- Layout recognition. Documents from different domains may have various layouts,
@@ -39,11 +73,18 @@ We use vision information to resolve problems as human being.
  - Footer
  - Reference
  - Equation

Try the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which has images demonstrating the detection results like the following
(a programmatic sketch follows the figures):
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/layout/layout.png?raw=true" width="900"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
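A minimal programmatic sketch, following the `layout` branch of t_recognizer.py below ('page.png' is a placeholder):

```python
# Sketch: layout detection in code; mirrors t_recognizer.py's "layout" mode.
import os
from PIL import Image
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import LayoutRecognizer, Recognizer

detr = Recognizer(LayoutRecognizer.labels, "layout.paper",
                  os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
layouts = detr([Image.open("page.png")], 0.2)   # one list of detections per input image
for box in layouts[0]:
    print(box["type"], round(box["score"], 2), box["bbox"])
```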
-  - Table Structure Recognition(TSR). Data table is a frequently used structure present data including numbers or text.
+  - Table Structure Recognition (TSR). A data table is a frequently used structure to present data, including numbers or text.
The structure of a table might be very complex, like hierarchical headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which can be well comprehended by LLMs.
We have five labels for the TSR task:
@@ -52,8 +93,15 @@ We use vision information to resolve problems as human being.
  - Column header
  - Projected row header
  - Spanning cell

Try the following command to see the table structure recognition results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
The inputs could be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which has both images and html pages demonstrating the detection results like the following
(a programmatic sketch follows the figures):
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://user-images.githubusercontent.com/10793386/139559159-cd23c972-8731-48ed-91df-f3f27e9f4d79.jpg" width="900"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
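The corresponding sketch for TSR, following the `tsr` branch of t_recognizer.py ('table.png' is a placeholder). The detected components carry `label` and coordinate keys, as the remapping in t_recognizer.py shows; assembling them together with OCR boxes into HTML is what `get_table_html()` there does.

```python
# Sketch: table structure detection; mirrors t_recognizer.py's "tsr" mode.
from PIL import Image
from deepdoc.vision import TableStructureRecognizer

tsr = TableStructureRecognizer()
components = tsr([Image.open("table.png")], 0.5)[0]
for c in components:
    # labels such as "table row", "table column header", "table spanning cell"
    print(c["label"], round(c["score"], 2), (c["x0"], c["top"], c["x1"], c["bottom"]))
```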
| <a name="3"></a> | |||
| @@ -71,4 +119,4 @@ The résumé is a very complicated kind of document. A résumé which is compose | |||
| with various layouts could be resolved into structured data composed of nearly a hundred of fields. | |||
| We haven't opened the parser yet, as we open the processing method after parsing procedure. | |||
@@ -1,4 +1,49 @@
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer


def init_in_out(args):
    from PIL import Image
    import fitz
    import os
    import traceback
    from api.utils.file_utils import traversal_files
    images = []
    outputs = []

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    def pdf_pages(fnm, zoomin=3):
        # Render every page of the PDF to an RGB image at the given zoom factor.
        nonlocal outputs, images
        pdf = fitz.open(fnm)
        mat = fitz.Matrix(zoomin, zoomin)
        for i, page in enumerate(pdf):
            pix = page.get_pixmap(matrix=mat)
            img = Image.frombytes("RGB", [pix.width, pix.height],
                                  pix.samples)
            images.append(img)
            outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

    def images_and_outputs(fnm):
        # Collect an input image (or all pages of a PDF) plus its output file name.
        nonlocal outputs, images
        if fnm.split(".")[-1].lower() == "pdf":
            pdf_pages(fnm)
            return
        try:
            images.append(Image.open(fnm))
            outputs.append(os.path.split(fnm)[-1])
        except Exception as e:
            traceback.print_exc()

    if os.path.isdir(args.inputs):
        for fnm in traversal_files(args.inputs):
            images_and_outputs(fnm)
    else:
        images_and_outputs(args.inputs)

    for i in range(len(outputs)):
        outputs[i] = os.path.join(args.output_dir, outputs[i])

    return images, outputs
@@ -1,17 +1,26 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from api.utils.file_utils import get_project_base_directory
-from .recognizer import Recognizer
+from deepdoc.vision import Recognizer


class LayoutRecognizer(Recognizer):
-    def __init__(self, domain):
-        self.layout_labels = [
+    labels = [
        "_background_",
        "Text",
        "Title",
@@ -24,7 +33,8 @@ class LayoutRecognizer(Recognizer):
        "Reference",
        "Equation",
    ]
-        super().__init__(self.layout_labels, domain,
+
+    def __init__(self, domain):
+        super().__init__(self.labels, domain,
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
@@ -37,7 +47,7 @@ class LayoutRecognizer(Recognizer):
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
-        # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
+        # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
@@ -117,3 +127,5 @@ class LayoutRecognizer(Recognizer):
        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]

        return ocr_res, page_layout
@@ -17,7 +17,6 @@ from copy import deepcopy
import onnxruntime as ort
from huggingface_hub import snapshot_download
from . import seeit
from .operators import *
from rag.settings import cron_logger
@@ -36,7 +35,7 @@ class Recognizer(object):
        """
        if not model_dir:
-            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
+            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
@@ -46,6 +45,9 @@ class Recognizer(object):
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.input_names = [node.name for node in self.ort_sess.get_inputs()]
        self.output_names = [node.name for node in self.ort_sess.get_outputs()]
        self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
        self.label_list = label_list

    @staticmethod
@@ -275,23 +277,131 @@ class Recognizer(object):
        return max_overlaped_i

    def preprocess(self, image_list):
-        preprocess_ops = []
-        for op_info in [
-            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
-            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
-            {'type': 'Permute'},
-            {'stride': 32, 'type': 'PadStride'}
-        ]:
-            new_op_info = op_info.copy()
-            op_type = new_op_info.pop('type')
-            preprocess_ops.append(eval(op_type)(**new_op_info))
-        inputs = []
-        for im_path in image_list:
-            im, im_info = preprocess(im_path, preprocess_ops)
-            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
+        inputs = []
+        if "scale_factor" in self.input_names:
+            preprocess_ops = []
+            for op_info in [
+                {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
+                {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
+                {'type': 'Permute'},
+                {'stride': 32, 'type': 'PadStride'}
+            ]:
+                new_op_info = op_info.copy()
+                op_type = new_op_info.pop('type')
+                preprocess_ops.append(eval(op_type)(**new_op_info))
+            for im_path in image_list:
+                im, im_info = preprocess(im_path, preprocess_ops)
+                inputs.append({"image": np.array((im,)).astype('float32'),
+                               "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
+        else:
+            hh, ww = self.input_shape
+            for img in image_list:
+                h, w = img.shape[:2]
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
+                # Scale input pixel values to 0 to 1
+                img /= 255.0
+                img = img.transpose(2, 0, 1)
+                img = img[np.newaxis, :, :, :].astype(np.float32)
+                inputs.append({self.input_names[0]: img, "scale_factor": [w / ww, h / hh]})
        return inputs
    def postprocess(self, boxes, inputs, thr):
        if "scale_factor" in self.input_names:
            bb = []
            for b in boxes:
                clsid, bbox, score = int(b[0]), b[2:], b[1]
                if score < thr:
                    continue
                if clsid >= len(self.label_list):
                    cron_logger.warning(f"bad category id")
                    continue
                bb.append({
                    "type": self.label_list[clsid].lower(),
                    "bbox": [float(t) for t in bbox.tolist()],
                    "score": float(score)
                })
            return bb

        def xywh2xyxy(x):
            # [x, y, w, h] to [x1, y1, x2, y2]
            y = np.copy(x)
            y[:, 0] = x[:, 0] - x[:, 2] / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            return y

        def compute_iou(box, boxes):
            # Compute xmin, ymin, xmax, ymax for both boxes
            xmin = np.maximum(box[0], boxes[:, 0])
            ymin = np.maximum(box[1], boxes[:, 1])
            xmax = np.minimum(box[2], boxes[:, 2])
            ymax = np.minimum(box[3], boxes[:, 3])

            # Compute intersection area
            intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

            # Compute union area
            box_area = (box[2] - box[0]) * (box[3] - box[1])
            boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            union_area = box_area + boxes_area - intersection_area

            # Compute IoU
            iou = intersection_area / union_area
            return iou

        def iou_filter(boxes, scores, iou_threshold):
            sorted_indices = np.argsort(scores)[::-1]
            keep_boxes = []
            while sorted_indices.size > 0:
                # Pick the box with the highest remaining score
                box_id = sorted_indices[0]
                keep_boxes.append(box_id)

                # Compute IoU of the picked box with the rest
                ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

                # Remove boxes with IoU over the threshold
                keep_indices = np.where(ious < iou_threshold)[0]
                sorted_indices = sorted_indices[keep_indices + 1]
            return keep_boxes

        boxes = np.squeeze(boxes).T
        # Filter out object confidence scores below threshold
        scores = np.max(boxes[:, 4:], axis=1)
        boxes = boxes[scores > thr, :]
        scores = scores[scores > thr]
        if len(boxes) == 0:
            return []

        # Get the class with the highest confidence
        class_ids = np.argmax(boxes[:, 4:], axis=1)
        boxes = boxes[:, :4]
        input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
        boxes = np.multiply(boxes, input_shape, dtype=np.float32)
        boxes = xywh2xyxy(boxes)

        unique_class_ids = np.unique(class_ids)
        indices = []
        for class_id in unique_class_ids:
            class_indices = np.where(class_ids == class_id)[0]
            class_boxes = boxes[class_indices, :]
            class_scores = scores[class_indices]
            class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
            indices.extend(class_indices[class_keep_boxes])

        return [{
            "type": self.label_list[class_ids[i]].lower(),
            "bbox": [float(t) for t in boxes[i].tolist()],
            "score": float(scores[i])
        } for i in indices]

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
@@ -306,22 +416,14 @@ class Recognizer(object):
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            print("preprocess")
            for ins in inputs:
-                bb = []
-                for b in self.ort_sess.run(None, ins)[0]:
-                    clsid, bbox, score = int(b[0]), b[2:], b[1]
-                    if score < thr:
-                        continue
-                    if clsid >= len(self.label_list):
-                        cron_logger.warning(f"bad category id")
-                        continue
-                    bb.append({
-                        "type": self.label_list[clsid].lower(),
-                        "bbox": [float(t) for t in bbox.tolist()],
-                        "score": float(score)
-                    })
+                bb = self.postprocess(self.ort_sess.run(None, {k: v for k, v in ins.items() if k in self.input_names})[0], ins, thr)
                res.append(bb)

        #seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
@@ -0,0 +1,47 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))

import numpy as np
import argparse
from deepdoc.vision import OCR, init_in_out
from deepdoc.vision.seeit import draw_box


def main(args):
    ocr = OCR()
    images, outputs = init_in_out(args)
    for i, img in enumerate(images):
        bxs = ocr(np.array(img))
        bxs = [(line[0], line[1][0]) for line in bxs]
        bxs = [{
            "text": t,
            "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
            "type": "ocr",
            "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
        img = draw_box(images[i], bxs, ["ocr"], 1.)
        img.save(outputs[i], quality=95)
        with open(outputs[i] + ".txt", "w+") as f:
            f.write("\n".join([o["text"] for o in bxs]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
                        default="./ocr_outputs")
    args = parser.parse_args()
    main(args)
@@ -0,0 +1,173 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
import re
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))

import argparse
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from deepdoc.vision.seeit import draw_box


def main(args):
    images, outputs = init_in_out(args)
    if args.mode.lower() == "layout":
        labels = LayoutRecognizer.labels
        detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
    if args.mode.lower() == "tsr":
        labels = TableStructureRecognizer.labels
        detr = TableStructureRecognizer()
        ocr = OCR()

    layouts = detr(images, float(args.threshold))
    for i, lyt in enumerate(layouts):
        if args.mode.lower() == "tsr":
            #lyt = [t for t in lyt if t["type"] == "table column"]
            html = get_table_html(images[i], lyt, ocr)
            with open(outputs[i] + ".html", "w+") as f:
                f.write(html)
            lyt = [{
                "type": t["label"],
                "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
                "score": t["score"]
            } for t in lyt]
        img = draw_box(images[i], lyt, labels, float(args.threshold))
        img.save(outputs[i], quality=95)
        print("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
    boxes = ocr(np.array(img))
    boxes = Recognizer.sort_Y_firstly(
        [{"x0": b[0][0], "x1": b[1][0],
          "top": b[0][1], "text": t[0],
          "bottom": b[-1][1],
          "layout_type": "table",
          "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
        np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
    )

    def gather(kwd, fzy=10, ption=0.6):
        nonlocal boxes
        eles = Recognizer.sort_Y_firstly(
            [r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
        eles = Recognizer.layouts_cleanup(boxes, eles, 5, ption)
        return Recognizer.sort_Y_firstly(eles, 0)

    headers = gather(r".*header$")
    rows = gather(r".* (row|header)")
    spans = gather(r".*spanning")
    clmns = sorted([r for r in tb_cpns if re.match(
        r"table column$", r["label"])], key=lambda x: x["x0"])
    clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)

    for b in boxes:
        ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
        if ii is not None:
            b["R"] = ii
            b["R_top"] = rows[ii]["top"]
            b["R_bott"] = rows[ii]["bottom"]

        ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
        if ii is not None:
            b["H_top"] = headers[ii]["top"]
            b["H_bott"] = headers[ii]["bottom"]
            b["H_left"] = headers[ii]["x0"]
            b["H_right"] = headers[ii]["x1"]
            b["H"] = ii

        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
        if ii is not None:
            b["C"] = ii
            b["C_left"] = clmns[ii]["x0"]
            b["C_right"] = clmns[ii]["x1"]

        ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
        if ii is not None:
            b["H_top"] = spans[ii]["top"]
            b["H_bott"] = spans[ii]["bottom"]
            b["H_left"] = spans[ii]["x0"]
            b["H_right"] = spans[ii]["x1"]
            b["SP"] = ii

    html = """
    <html>
    <head>
    <style>
    ._table_1nkzy_11 {
        margin: auto;
        width: 70%%;
        padding: 10px;
    }
    ._table_1nkzy_11 p {
        margin-bottom: 50px;
        border: 1px solid #e1e1e1;
    }
    caption {
        color: #6ac1ca;
        font-size: 20px;
        height: 50px;
        line-height: 50px;
        font-weight: 600;
        margin-bottom: 10px;
    }
    ._table_1nkzy_11 table {
        width: 100%%;
        border-collapse: collapse;
    }
    th {
        color: #fff;
        background-color: #6ac1ca;
    }
    td:hover {
        background: #c1e8e8;
    }
    tr:nth-child(even) {
        background-color: #f2f2f2;
    }
    ._table_1nkzy_11 th,
    ._table_1nkzy_11 td {
        text-align: center;
        border: 1px solid #ddd;
        padding: 8px;
    }
    </style>
    </head>
    <body>
    %s
    </body>
    </html>
""" % TableStructureRecognizer.construct_table(boxes, html=True)
    return html


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
                        default="./layouts_outputs")
    parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5)
    parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
                        default="layout")
    args = parser.parse_args()
    main(args)
@@ -1,3 +1,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
@@ -12,15 +24,16 @@ from .recognizer import Recognizer


class TableStructureRecognizer(Recognizer):
+    labels = [
+        "table",
+        "table column",
+        "table row",
+        "table column header",
+        "table projected row header",
+        "table spanning cell",
+    ]

    def __init__(self):
-        self.labels = [
-            "table",
-            "table column",
-            "table row",
-            "table column header",
-            "table projected row header",
-            "table spanning cell",
-        ]
        super().__init__(self.labels, "tsr",
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
@@ -79,7 +92,8 @@ class TableStructureRecognizer(Recognizer):
                return True
        return False

-    def __blockType(self, b):
+    @staticmethod
+    def blockType(b):
        patt = [
            ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
            (r"^(20|19)[0-9]{2}年$", "Dt"),
@@ -109,11 +123,12 @@ class TableStructureRecognizer(Recognizer):
        return "Ot"

-    def construct_table(self, boxes, is_english=False, html=False):
+    @staticmethod
+    def construct_table(boxes, is_english=False, html=False):
        cap = ""
        i = 0
        while i < len(boxes):
-            if self.is_caption(boxes[i]):
+            if TableStructureRecognizer.is_caption(boxes[i]):
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
@@ -122,14 +137,15 @@ class TableStructureRecognizer(Recognizer):
        if not boxes:
            return []
        for b in boxes:
-            b["btype"] = self.__blockType(b)
+            b["btype"] = TableStructureRecognizer.blockType(b)
        max_type = Counter([b["btype"] for b in boxes]).items()
        max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
        logging.debug("MAXTYPE: " + max_type)

        rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
        rowh = np.min(rowh) if rowh else 0
-        boxes = self.sort_R_firstly(boxes, rowh / 2)
+        boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
        #for b in boxes:print(b)
        boxes[0]["rn"] = 0
        rows = [[boxes[0]]]
        btm = boxes[0]["bottom"]
@@ -150,9 +166,9 @@ class TableStructureRecognizer(Recognizer):
        colwm = np.min(colwm) if colwm else 0
        crosspage = len(set([b["page_number"] for b in boxes])) > 1
        if crosspage:
-            boxes = self.sort_X_firstly(boxes, colwm / 2, False)
+            boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
        else:
-            boxes = self.sort_C_firstly(boxes, colwm / 2)
+            boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
        boxes[0]["cn"] = 0
        cols = [[boxes[0]]]
        right = boxes[0]["x1"]
@@ -313,16 +329,18 @@ class TableStructureRecognizer(Recognizer):
                    hdset.add(i)

        if html:
-            return [self.__html_table(cap, hdset,
-                                      self.__cal_spans(boxes, rows,
-                                                       cols, tbl, True)
-                                      )]
+            return TableStructureRecognizer.__html_table(cap, hdset,
+                                                         TableStructureRecognizer.__cal_spans(boxes, rows,
+                                                                                              cols, tbl, True)
+                                                         )

-        return self.__desc_table(cap, hdset,
-                                 self.__cal_spans(boxes, rows, cols, tbl, False),
-                                 is_english)
+        return TableStructureRecognizer.__desc_table(cap, hdset,
+                                                     TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
+                                                                                          False),
+                                                     is_english)

-    def __html_table(self, cap, hdset, tbl):
+    @staticmethod
+    def __html_table(cap, hdset, tbl):
        # construct HTML
        html = "<table>"
        if cap:
@@ -339,8 +357,8 @@ class TableStructureRecognizer(Recognizer):
                txt = ""
                if arr:
                    h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
-                    txt = "".join([c["text"]
-                                   for c in self.sort_Y_firstly(arr, h)])
+                    txt = " ".join([c["text"]
+                                    for c in Recognizer.sort_Y_firstly(arr, h)])
                txts.append(txt)
                sp = ""
                if arr[0].get("colspan"):
@@ -366,7 +384,8 @@ class TableStructureRecognizer(Recognizer):
        html += "\n</table>"
        return html

-    def __desc_table(self, cap, hdr_rowno, tbl, is_english):
+    @staticmethod
+    def __desc_table(cap, hdr_rowno, tbl, is_english):
        # get the text of every column in the header row to become the header text
        clmno = len(tbl[0])
        rowno = len(tbl)
@@ -469,7 +488,8 @@ class TableStructureRecognizer(Recognizer):
            row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
        return row_txt

-    def __cal_spans(self, boxes, rows, cols, tbl, html=True):
+    @staticmethod
+    def __cal_spans(boxes, rows, cols, tbl, html=True):
        # calculate spans
        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
                for cln in cols]
@@ -553,4 +573,3 @@ class TableStructureRecognizer(Recognizer):
        tbl[rowspan[0]][colspan[0]] = arr
        return tbl