conv = {
    "id": get_uuid(),
    "dialog_id": req["dialog_id"],
    "name": req.get("name", "New conversation"),
    "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
ConversationService.save(**conv)
def list_conversation():
    dialog_id = request.args["dialog_id"]
    try:
        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
        # Surface the error instead of swallowing it and implicitly returning None
        return server_error_response(e)
def traversal_files(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            fullname = os.path.join(root, f)
            yield fullname
With a bunch of documents from various domains, in various formats, and with diverse retrieval requirements,
accurate analysis becomes a very challenging task. *Deep*Doc was born for that purpose.
There are 2 parts in *Deep*Doc so far: vision and parser.
You can run the following test programs if you're interested in our results of OCR, layout recognition and TSR.
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
  -h, --help            show this help message and exit
  --inputs INPUTS       Directory where to store images or PDFs, or a file path to a single image or PDF
  --output_dir OUTPUT_DIR
                        Directory where to store the output images. Default: './layouts_outputs'
  --threshold THRESHOLD
                        A threshold to filter out detections. Default: 0.5
  --mode {layout,tsr}   Task mode: layout recognition or table structure recognition
```
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this mirror might help:

```bash
export HF_ENDPOINT=https://hf-mirror.com
```
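The same endpoint can also be selected from Python; a minimal sketch, assuming `huggingface_hub` reads `HF_ENDPOINT` at import time, so the variable must be set before the import:

```python
import os

# Point huggingface_hub at the mirror *before* importing it;
# this is equivalent to the shell export above.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download

# Repo id taken from the Recognizer code in this repository.
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
print("models downloaded to:", model_dir)
```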
| <a name="2"></a> | <a name="2"></a> | ||||
| ## 2. Vision | ## 2. Vision | ||||
| We use vision information to resolve problems as human being. | We use vision information to resolve problems as human being. | ||||
- OCR. Since a lot of documents are presented as images, or can at least be transformed into images,
OCR is a very essential and fundamental, even universal, solution for text extraction.

```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```

The inputs could be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the positions of the results,
along with txt files containing the OCR text. A Python sketch of the same flow follows the example images below.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://lh6.googleusercontent.com/2xdiSjaGWkZ71YdORc71Ujf7jCHmO6G-6ONklzGiUYEh3QZpjPo6MQ9eqEFX20am_cdW4Ck0YRraXEetXWnM08kJd99yhik13Cy0_YKUAq2zVGR15LzkovRAmK9iT4o3hcJ8dTpspaJKUwt6R4gN7So" width="300"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
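If you'd rather call the OCR engine from Python than through the test script, the sketch below mirrors what `t_ocr.py` does; the input path is hypothetical:

```python
import numpy as np
from PIL import Image
from deepdoc.vision import OCR

ocr = OCR()                    # loads the OCR models (downloads them on first use)
img = Image.open("page.jpg")   # hypothetical input image

# Each result line is (box_points, (text, confidence)),
# the same unpacking t_ocr.py performs.
for box, (text, score) in ocr(np.array(img)):
    print(text, score)
```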
- Layout recognition. Documents from different domains may have various layouts,
  - Footer
  - Reference
  - Equation

Have a try with the following command to see the layout detection results.

```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```

The inputs could be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains images demonstrating the detection results, as follows.
A programmatic sketch follows the example images.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/layout/layout.png?raw=true" width="900"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
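The same detection can be run from Python; a minimal sketch following the `t_recognizer.py` code in this repository (the input image path is hypothetical):

```python
import os
from PIL import Image
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer, LayoutRecognizer

# Build the layout detector the same way t_recognizer.py does.
detr = Recognizer(LayoutRecognizer.labels, "layout.paper",
                  os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

images = [Image.open("page.jpg")]   # hypothetical input
for page in detr(images, 0.2):      # one list of detections per image
    for box in page:                # each: {"type", "bbox", "score"}
        print(box["type"], box["score"], box["bbox"])
```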
- Table Structure Recognition (TSR). A data table is a frequently used structure to present data, including numbers or text.
The structure of a table might be very complex, like hierarchical headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which can be well comprehended by LLMs.
We have five labels for the TSR task:
  - Column
  - Row
  - Column header
  - Projected row header
  - Spanning cell

Have a try with the following command to see the table structure recognition results.

```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```

The inputs could be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which contains both images and HTML pages demonstrating the detection results, as follows.
A programmatic sketch follows the example images.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://user-images.githubusercontent.com/10793386/139559159-cd23c972-8731-48ed-91df-f3f27e9f4d79.jpg" width="900"/>
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
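Programmatically, the TSR output can be combined with OCR into an HTML table via the `get_table_html` helper defined in `t_recognizer.py`; a sketch under that assumption, with a hypothetical input path:

```python
from PIL import Image
from deepdoc.vision import OCR, TableStructureRecognizer
from deepdoc.vision.t_recognizer import get_table_html  # helper shown later in this diff

ocr = OCR()
tsr = TableStructureRecognizer()

img = Image.open("table_page.jpg")  # hypothetical image of a table
components = tsr([img], 0.2)[0]     # detected rows, columns, headers, spanning cells

# Reassemble the OCR text along the detected structure into HTML.
html = get_table_html(img, components, ocr)
with open("table.html", "w") as f:
    f.write(html)
```

Passing `html=False` to `TableStructureRecognizer.construct_table` instead yields the row-wise sentences mentioned above, which are what get fed to the LLM.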
| <a name="3"></a> | <a name="3"></a> | ||||
| with various layouts could be resolved into structured data composed of nearly a hundred of fields. | with various layouts could be resolved into structured data composed of nearly a hundred of fields. | ||||
| We haven't opened the parser yet, as we open the processing method after parsing procedure. | We haven't opened the parser yet, as we open the processing method after parsing procedure. | ||||
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
def init_in_out(args):
    from PIL import Image
    import fitz
    import os
    import traceback
    from api.utils.file_utils import traversal_files

    images = []
    outputs = []

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    def pdf_pages(fnm, zoomin=3):
        nonlocal outputs, images
        pdf = fitz.open(fnm)
        mat = fitz.Matrix(zoomin, zoomin)
        for i, page in enumerate(pdf):
            pix = page.get_pixmap(matrix=mat)
            img = Image.frombytes("RGB", [pix.width, pix.height],
                                  pix.samples)
            images.append(img)
            outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

    def images_and_outputs(fnm):
        nonlocal outputs, images
        if fnm.split(".")[-1].lower() == "pdf":
            pdf_pages(fnm)
            return
        try:
            images.append(Image.open(fnm))
            outputs.append(os.path.split(fnm)[-1])
        except Exception:
            traceback.print_exc()

    if os.path.isdir(args.inputs):
        for fnm in traversal_files(args.inputs):
            images_and_outputs(fnm)
    else:
        images_and_outputs(args.inputs)

    for i in range(len(outputs)):
        outputs[i] = os.path.join(args.output_dir, outputs[i])
    return images, outputs
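For reference, `init_in_out` only needs an object carrying `inputs` and `output_dir` attributes, so it can be driven without argparse; a hypothetical sketch:

```python
from types import SimpleNamespace
from deepdoc.vision import init_in_out

# Any object with .inputs and .output_dir works; argparse.Namespace is simply
# what the test scripts happen to pass in. The paths here are hypothetical.
args = SimpleNamespace(inputs="docs/sample.pdf", output_dir="./ocr_outputs")
images, outputs = init_in_out(args)  # PIL page images plus matching output file paths
```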
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer
class LayoutRecognizer(Recognizer):
    labels = [
        "_background_",
        "Text",
        "Title",
        "Reference",
        "Equation",
    ]

    def __init__(self, domain):
        super().__init__(self.labels, domain,
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
        return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
import onnxruntime as ort
from huggingface_hub import snapshot_download

from . import seeit
from .operators import *
from rag.settings import cron_logger
| """ | """ | ||||
| if not model_dir: | if not model_dir: | ||||
| model_dir = snapshot_download(repo_id="InfiniFlow/ocr") | |||||
| model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") | |||||
| model_file_path = os.path.join(model_dir, task_name + ".onnx") | model_file_path = os.path.join(model_dir, task_name + ".onnx") | ||||
| if not os.path.exists(model_file_path): | if not os.path.exists(model_file_path): | ||||
| self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) | self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) | ||||
| else: | else: | ||||
| self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) | self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) | ||||
| self.input_names = [node.name for node in self.ort_sess.get_inputs()] | |||||
| self.output_names = [node.name for node in self.ort_sess.get_outputs()] | |||||
| self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4] | |||||
| self.label_list = label_list | self.label_list = label_list | ||||
    @staticmethod
        return max_overlaped_i

    def preprocess(self, image_list):
        inputs = []
        if "scale_factor" in self.input_names:
            # PaddleDetection-style models take an explicit "scale_factor" input;
            # rebuild the exporter's preprocessing pipeline from its op config.
            preprocess_ops = []
            for op_info in [
                {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
                {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
                {'type': 'Permute'},
                {'stride': 32, 'type': 'PadStride'}
            ]:
                new_op_info = op_info.copy()
                op_type = new_op_info.pop('type')
                preprocess_ops.append(eval(op_type)(**new_op_info))

            for im_path in image_list:
                im, im_info = preprocess(im_path, preprocess_ops)
                inputs.append({"image": np.array((im,)).astype('float32'),
                               "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        else:
            # Fixed-input-size detector: resize to the model's input shape,
            # scale pixels to [0, 1] and convert HWC to CHW.
            hh, ww = self.input_shape
            for img in image_list:
                h, w = img.shape[:2]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
                # Scale input pixel values to 0 to 1
                img /= 255.0
                img = img.transpose(2, 0, 1)
                img = img[np.newaxis, :, :, :].astype(np.float32)
                inputs.append({self.input_names[0]: img, "scale_factor": [w / ww, h / hh]})
        return inputs
    def postprocess(self, boxes, inputs, thr):
        if "scale_factor" in self.input_names:
            # PaddleDetection-style output: each row is [class_id, score, x0, y0, x1, y1].
            bb = []
            for b in boxes:
                clsid, bbox, score = int(b[0]), b[2:], b[1]
                if score < thr:
                    continue
                if clsid >= len(self.label_list):
                    cron_logger.warning("bad category id")
                    continue
                bb.append({
                    "type": self.label_list[clsid].lower(),
                    "bbox": [float(t) for t in bbox.tolist()],
                    "score": float(score)
                })
            return bb

        def xywh2xyxy(x):
            # [x, y, w, h] to [x1, y1, x2, y2]
            y = np.copy(x)
            y[:, 0] = x[:, 0] - x[:, 2] / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            return y

        def compute_iou(box, boxes):
            # Compute xmin, ymin, xmax, ymax for both boxes
            xmin = np.maximum(box[0], boxes[:, 0])
            ymin = np.maximum(box[1], boxes[:, 1])
            xmax = np.minimum(box[2], boxes[:, 2])
            ymax = np.minimum(box[3], boxes[:, 3])

            # Compute intersection area
            intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

            # Compute union area
            box_area = (box[2] - box[0]) * (box[3] - box[1])
            boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            union_area = box_area + boxes_area - intersection_area

            # Compute IoU
            iou = intersection_area / union_area
            return iou

        def iou_filter(boxes, scores, iou_threshold):
            # Greedy non-maximum suppression: keep the highest-scored box and
            # drop every remaining box that overlaps it beyond the threshold.
            sorted_indices = np.argsort(scores)[::-1]
            keep_boxes = []
            while sorted_indices.size > 0:
                # Pick the box with the highest remaining score
                box_id = sorted_indices[0]
                keep_boxes.append(box_id)

                # Compute IoU of the picked box with the rest
                ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

                # Remove boxes with IoU over the threshold
                keep_indices = np.where(ious < iou_threshold)[0]
                sorted_indices = sorted_indices[keep_indices + 1]
            return keep_boxes

        # YOLO-style output: after squeeze/transpose, each row is
        # [cx, cy, w, h, class scores...].
        boxes = np.squeeze(boxes).T
        # Filter out object confidence scores below threshold
        scores = np.max(boxes[:, 4:], axis=1)
        boxes = boxes[scores > thr, :]
        scores = scores[scores > thr]
        if len(boxes) == 0:
            return []

        # Get the class with the highest confidence
        class_ids = np.argmax(boxes[:, 4:], axis=1)
        boxes = boxes[:, :4]
        input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1],
                                inputs["scale_factor"][0], inputs["scale_factor"][1]])
        boxes = np.multiply(boxes, input_shape, dtype=np.float32)
        boxes = xywh2xyxy(boxes)

        # Apply NMS per class to suppress overlapping detections
        unique_class_ids = np.unique(class_ids)
        indices = []
        for class_id in unique_class_ids:
            class_indices = np.where(class_ids == class_id)[0]
            class_boxes = boxes[class_indices, :]
            class_scores = scores[class_indices]
            class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
            indices.extend(class_indices[class_keep_boxes])

        return [{
            "type": self.label_list[class_ids[i]].lower(),
            "bbox": [float(t) for t in boxes[i].tolist()],
            "score": float(scores[i])
        } for i in indices]
    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        end_index = min((i + 1) * batch_size, len(imgs))
        batch_image_list = imgs[start_index:end_index]
        inputs = self.preprocess(batch_image_list)
        for ins in inputs:
            # Only feed the inputs this model actually declares.
            bb = self.postprocess(self.ort_sess.run(None, {k: v for k, v in ins.items() if k in self.input_names})[0], ins, thr)
            res.append(bb)
        #seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))

import numpy as np
import argparse

from deepdoc.vision import OCR, init_in_out
from deepdoc.vision.seeit import draw_box


def main(args):
    ocr = OCR()
    images, outputs = init_in_out(args)
    for i, img in enumerate(images):
        bxs = ocr(np.array(img))
        bxs = [(line[0], line[1][0]) for line in bxs]
        bxs = [{
            "text": t,
            "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
            "type": "ocr",
            "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
        img = draw_box(images[i], bxs, ["ocr"], 1.)
        img.save(outputs[i], quality=95)
        with open(outputs[i] + ".txt", "w+") as f:
            f.write("\n".join([o["text"] for o in bxs]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
                        default="./ocr_outputs")
    args = parser.parse_args()
    main(args)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
import re

import numpy as np

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import argparse

from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from deepdoc.vision.seeit import draw_box


def main(args):
    images, outputs = init_in_out(args)
    if args.mode.lower() == "layout":
        labels = LayoutRecognizer.labels
        detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
    if args.mode.lower() == "tsr":
        labels = TableStructureRecognizer.labels
        detr = TableStructureRecognizer()
        ocr = OCR()

    layouts = detr(images, float(args.threshold))
    for i, lyt in enumerate(layouts):
        if args.mode.lower() == "tsr":
            #lyt = [t for t in lyt if t["type"] == "table column"]
            html = get_table_html(images[i], lyt, ocr)
            with open(outputs[i] + ".html", "w+") as f:
                f.write(html)
            lyt = [{
                "type": t["label"],
                "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
                "score": t["score"]
            } for t in lyt]
        img = draw_box(images[i], lyt, labels, float(args.threshold))
        img.save(outputs[i], quality=95)
        print("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
    boxes = ocr(np.array(img))
    boxes = Recognizer.sort_Y_firstly(
        [{"x0": b[0][0], "x1": b[1][0],
          "top": b[0][1], "text": t[0],
          "bottom": b[-1][1],
          "layout_type": "table",
          "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
        np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
    )

    def gather(kwd, fzy=10, ption=0.6):
        nonlocal boxes
        eles = Recognizer.sort_Y_firstly(
            [r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
        eles = Recognizer.layouts_cleanup(boxes, eles, 5, ption)
        return Recognizer.sort_Y_firstly(eles, 0)

    headers = gather(r".*header$")
    rows = gather(r".* (row|header)")
    spans = gather(r".*spanning")
    clmns = sorted([r for r in tb_cpns if re.match(
        r"table column$", r["label"])], key=lambda x: x["x0"])
    clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)

    for b in boxes:
        ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
        if ii is not None:
            b["R"] = ii
            b["R_top"] = rows[ii]["top"]
            b["R_bott"] = rows[ii]["bottom"]

        ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
        if ii is not None:
            b["H_top"] = headers[ii]["top"]
            b["H_bott"] = headers[ii]["bottom"]
            b["H_left"] = headers[ii]["x0"]
            b["H_right"] = headers[ii]["x1"]
            b["H"] = ii

        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
        if ii is not None:
            b["C"] = ii
            b["C_left"] = clmns[ii]["x0"]
            b["C_right"] = clmns[ii]["x1"]

        ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
        if ii is not None:
            b["H_top"] = spans[ii]["top"]
            b["H_bott"] = spans[ii]["bottom"]
            b["H_left"] = spans[ii]["x0"]
            b["H_right"] = spans[ii]["x1"]
            b["SP"] = ii

    html = """
<html>
<head>
<style>
._table_1nkzy_11 {
  margin: auto;
  width: 70%%;
  padding: 10px;
}
._table_1nkzy_11 p {
  margin-bottom: 50px;
  border: 1px solid #e1e1e1;
}
caption {
  color: #6ac1ca;
  font-size: 20px;
  height: 50px;
  line-height: 50px;
  font-weight: 600;
  margin-bottom: 10px;
}
._table_1nkzy_11 table {
  width: 100%%;
  border-collapse: collapse;
}
th {
  color: #fff;
  background-color: #6ac1ca;
}
td:hover {
  background: #c1e8e8;
}
tr:nth-child(even) {
  background-color: #f2f2f2;
}
._table_1nkzy_11 th,
._table_1nkzy_11 td {
  text-align: center;
  border: 1px solid #ddd;
  padding: 8px;
}
</style>
</head>
<body>
%s
</body>
</html>
""" % TableStructureRecognizer.construct_table(boxes, html=True)
    return html
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
                        default="./layouts_outputs")
    parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5)
    parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
                        default="layout")
    args = parser.parse_args()
    main(args)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
class TableStructureRecognizer(Recognizer):
    labels = [
        "table",
        "table column",
        "table row",
        "table column header",
        "table projected row header",
        "table spanning cell",
    ]

    def __init__(self):
        super().__init__(self.labels, "tsr",
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
            return True
        return False

    @staticmethod
    def blockType(b):
        patt = [
            (r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
            (r"^(20|19)[0-9]{2}年$", "Dt"),
        return "Ot"
    @staticmethod
    def construct_table(boxes, is_english=False, html=False):
        cap = ""
        i = 0
        while i < len(boxes):
            if TableStructureRecognizer.is_caption(boxes[i]):
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
        if not boxes:
            return []
        for b in boxes:
            b["btype"] = TableStructureRecognizer.blockType(b)
        max_type = Counter([b["btype"] for b in boxes]).items()
        max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
        logging.debug("MAXTYPE: " + max_type)

        rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
        rowh = np.min(rowh) if rowh else 0
        boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
| boxes[0]["rn"] = 0 | boxes[0]["rn"] = 0 | ||||
| rows = [[boxes[0]]] | rows = [[boxes[0]]] | ||||
| btm = boxes[0]["bottom"] | btm = boxes[0]["bottom"] | ||||
| colwm = np.min(colwm) if colwm else 0 | colwm = np.min(colwm) if colwm else 0 | ||||
| crosspage = len(set([b["page_number"] for b in boxes])) > 1 | crosspage = len(set([b["page_number"] for b in boxes])) > 1 | ||||
| if crosspage: | if crosspage: | ||||
| boxes = self.sort_X_firstly(boxes, colwm / 2, False) | |||||
| boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False) | |||||
| else: | else: | ||||
| boxes = self.sort_C_firstly(boxes, colwm / 2) | |||||
| boxes = Recognizer.sort_C_firstly(boxes, colwm / 2) | |||||
| boxes[0]["cn"] = 0 | boxes[0]["cn"] = 0 | ||||
| cols = [[boxes[0]]] | cols = [[boxes[0]]] | ||||
| right = boxes[0]["x1"] | right = boxes[0]["x1"] | ||||
                hdset.add(i)

        if html:
            return TableStructureRecognizer.__html_table(cap, hdset,
                                                         TableStructureRecognizer.__cal_spans(boxes, rows,
                                                                                              cols, tbl, True)
                                                         )
        return TableStructureRecognizer.__desc_table(cap, hdset,
                                                     TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
                                                                                          False),
                                                     is_english)
    @staticmethod
    def __html_table(cap, hdset, tbl):
        # construct HTML
        html = "<table>"
        if cap:

        txt = ""
        if arr:
            h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
            txt = " ".join([c["text"]
                            for c in Recognizer.sort_Y_firstly(arr, h)])
        txts.append(txt)
        sp = ""
        if arr[0].get("colspan"):

        html += "\n</table>"
        return html
    @staticmethod
    def __desc_table(cap, hdr_rowno, tbl, is_english):
        # get the text of every column in the header row to become the header text
        clmno = len(tbl[0])
        rowno = len(tbl)

        row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
        return row_txt
    @staticmethod
    def __cal_spans(boxes, rows, cols, tbl, html=True):
        # calculate spans
        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
                for cln in cols]

        tbl[rowspan[0]][colspan[0]] = arr
        return tbl