Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

layout_recognizer.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. import re
  3. from collections import Counter
  4. from copy import deepcopy
  5. import numpy as np
  6. from api.utils.file_utils import get_project_base_directory
  7. from .recognizer import Recognizer
  8. class LayoutRecognizer(Recognizer):
  9. def __init__(self, domain):
  10. self.layout_labels = [
  11. "_background_",
  12. "Text",
  13. "Title",
  14. "Figure",
  15. "Figure caption",
  16. "Table",
  17. "Table caption",
  18. "Header",
  19. "Footer",
  20. "Reference",
  21. "Equation",
  22. ]
  23. super().__init__(self.layout_labels, domain,
  24. os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
  25. def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
  26. def __is_garbage(b):
  27. patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
  28. r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
  29. "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
  30. "\\(cid *: *[0-9]+ *\\)"
  31. ]
  32. return any([re.search(p, b["text"]) for p in patt])
  33. layouts = super().__call__(image_list, thr, batch_size)
  34. # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
  35. assert len(image_list) == len(ocr_res)
  36. # Tag layout type
  37. boxes = []
  38. assert len(image_list) == len(layouts)
  39. garbages = {}
  40. page_layout = []
  41. for pn, lts in enumerate(layouts):
  42. bxs = ocr_res[pn]
  43. lts = [{"type": b["type"],
  44. "score": float(b["score"]),
  45. "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
  46. "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
  47. "page_number": pn,
  48. } for b in lts]
  49. lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
  50. lts = self.layouts_cleanup(bxs, lts)
  51. page_layout.append(lts)
  52. # Tag layout type, layouts are ready
  53. def findLayout(ty):
  54. nonlocal bxs, lts, self
  55. lts_ = [lt for lt in lts if lt["type"] == ty]
  56. i = 0
  57. while i < len(bxs):
  58. if bxs[i].get("layout_type"):
  59. i += 1
  60. continue
  61. if __is_garbage(bxs[i]):
  62. bxs.pop(i)
  63. continue
  64. ii = self.find_overlapped_with_threashold(bxs[i], lts_,
  65. thr=0.4)
  66. if ii is None: # belong to nothing
  67. bxs[i]["layout_type"] = ""
  68. i += 1
  69. continue
  70. lts_[ii]["visited"] = True
  71. if lts_[ii]["type"] in ["footer", "header", "reference"]:
  72. if lts_[ii]["type"] not in garbages:
  73. garbages[lts_[ii]["type"]] = []
  74. garbages[lts_[ii]["type"]].append(bxs[i]["text"])
  75. bxs.pop(i)
  76. continue
  77. bxs[i]["layoutno"] = f"{ty}-{ii}"
  78. bxs[i]["layout_type"] = lts_[ii]["type"]
  79. i += 1
  80. for lt in ["footer", "header", "reference", "figure caption",
  81. "table caption", "title", "text", "table", "figure", "equation"]:
  82. findLayout(lt)
  83. # add box to figure layouts which has not text box
  84. for i, lt in enumerate(
  85. [lt for lt in lts if lt["type"] == "figure"]):
  86. if lt.get("visited"):
  87. continue
  88. lt = deepcopy(lt)
  89. del lt["type"]
  90. lt["text"] = ""
  91. lt["layout_type"] = "figure"
  92. lt["layoutno"] = f"figure-{i}"
  93. bxs.append(lt)
  94. boxes.extend(bxs)
  95. ocr_res = boxes
  96. garbag_set = set()
  97. for k in garbages.keys():
  98. garbages[k] = Counter(garbages[k])
  99. for g, c in garbages[k].items():
  100. if c > 1:
  101. garbag_set.add(g)
  102. ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
  103. return ocr_res, page_layout