You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

layout_recognizer.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. #
  13. import os
  14. import re
  15. from collections import Counter
  16. from copy import deepcopy
  17. import numpy as np
  18. from huggingface_hub import snapshot_download
  19. from api.utils.file_utils import get_project_base_directory
  20. from deepdoc.vision import Recognizer
  21. class LayoutRecognizer(Recognizer):
  22. labels = [
  23. "_background_",
  24. "Text",
  25. "Title",
  26. "Figure",
  27. "Figure caption",
  28. "Table",
  29. "Table caption",
  30. "Header",
  31. "Footer",
  32. "Reference",
  33. "Equation",
  34. ]
  35. def __init__(self, domain):
  36. try:
  37. model_dir = os.path.join(
  38. get_project_base_directory(),
  39. "rag/res/deepdoc")
  40. super().__init__(self.labels, domain, model_dir)
  41. except Exception:
  42. model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
  43. local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
  44. local_dir_use_symlinks=False)
  45. super().__init__(self.labels, domain, model_dir)
  46. self.garbage_layouts = ["footer", "header", "reference"]
  47. def __call__(self, image_list, ocr_res, scale_factor=3,
  48. thr=0.2, batch_size=16, drop=True):
  49. def __is_garbage(b):
  50. patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
  51. r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
  52. "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
  53. "\\(cid *: *[0-9]+ *\\)"
  54. ]
  55. return any([re.search(p, b["text"]) for p in patt])
  56. layouts = super().__call__(image_list, thr, batch_size)
  57. # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
  58. assert len(image_list) == len(ocr_res)
  59. # Tag layout type
  60. boxes = []
  61. assert len(image_list) == len(layouts)
  62. garbages = {}
  63. page_layout = []
  64. for pn, lts in enumerate(layouts):
  65. bxs = ocr_res[pn]
  66. lts = [{"type": b["type"],
  67. "score": float(b["score"]),
  68. "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
  69. "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
  70. "page_number": pn,
  71. } for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
  72. lts = self.sort_Y_firstly(lts, np.mean(
  73. [lt["bottom"] - lt["top"] for lt in lts]) / 2)
  74. lts = self.layouts_cleanup(bxs, lts)
  75. page_layout.append(lts)
  76. # Tag layout type, layouts are ready
  77. def findLayout(ty):
  78. nonlocal bxs, lts, self
  79. lts_ = [lt for lt in lts if lt["type"] == ty]
  80. i = 0
  81. while i < len(bxs):
  82. if bxs[i].get("layout_type"):
  83. i += 1
  84. continue
  85. if __is_garbage(bxs[i]):
  86. bxs.pop(i)
  87. continue
  88. ii = self.find_overlapped_with_threashold(bxs[i], lts_,
  89. thr=0.4)
  90. if ii is None: # belong to nothing
  91. bxs[i]["layout_type"] = ""
  92. i += 1
  93. continue
  94. lts_[ii]["visited"] = True
  95. keep_feats = [
  96. lts_[
  97. ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
  98. lts_[
  99. ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
  100. ]
  101. if drop and lts_[
  102. ii]["type"] in self.garbage_layouts and not any(keep_feats):
  103. if lts_[ii]["type"] not in garbages:
  104. garbages[lts_[ii]["type"]] = []
  105. garbages[lts_[ii]["type"]].append(bxs[i]["text"])
  106. bxs.pop(i)
  107. continue
  108. bxs[i]["layoutno"] = f"{ty}-{ii}"
  109. bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
  110. ii]["type"] != "equation" else "figure"
  111. i += 1
  112. for lt in ["footer", "header", "reference", "figure caption",
  113. "table caption", "title", "table", "text", "figure", "equation"]:
  114. findLayout(lt)
  115. # add box to figure layouts which has not text box
  116. for i, lt in enumerate(
  117. [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
  118. if lt.get("visited"):
  119. continue
  120. lt = deepcopy(lt)
  121. del lt["type"]
  122. lt["text"] = ""
  123. lt["layout_type"] = "figure"
  124. lt["layoutno"] = f"figure-{i}"
  125. bxs.append(lt)
  126. boxes.extend(bxs)
  127. ocr_res = boxes
  128. garbag_set = set()
  129. for k in garbages.keys():
  130. garbages[k] = Counter(garbages[k])
  131. for g, c in garbages[k].items():
  132. if c > 1:
  133. garbag_set.add(g)
  134. ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
  135. return ocr_res, page_layout