Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

table_structure_recognizer.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. #
  13. import logging
  14. import os
  15. import re
  16. from collections import Counter
  17. import numpy as np
  18. from huggingface_hub import snapshot_download
  19. from api.utils.file_utils import get_project_base_directory
  20. from rag.nlp import huqie
  21. from .recognizer import Recognizer
  22. class TableStructureRecognizer(Recognizer):
  # Detection class names; __init__ hands this list to the Recognizer base
  # constructor, and __call__ matches on these strings (e.g. "table column").
  labels = [
      "table",
      "table column",
      "table row",
      "table column header",
      "table projected row header",
      "table spanning cell",
  ]
  def __init__(self):
      """Initialise the base recognizer for the "tsr" task.

      The resources are first looked up under ``rag/res/deepdoc`` inside the
      project base directory.  If that fails for any reason, the
      ``InfiniFlow/deepdoc`` snapshot is fetched through ``snapshot_download``
      and used instead.
      """
      try:
          super().__init__(self.labels, "tsr", os.path.join(
              get_project_base_directory(),
              "rag/res/deepdoc"))
      except Exception:
          # Best-effort fallback: keep the broad catch, but leave a trace of
          # why the local resources were rejected instead of hiding it.
          logging.warning(
              "Failed to initialise 'tsr' from the local 'rag/res/deepdoc' "
              "directory; falling back to the 'InfiniFlow/deepdoc' snapshot.",
              exc_info=True)
          super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc"))
  def __call__(self, images, thr=0.2):
      """Detect table structure elements and align their edges.

      The base recognizer is run on ``images`` with threshold ``thr``.  For
      each returned group of detections the elements are converted to dicts
      with the keys ``label``, ``score``, ``x0``, ``x1``, ``top`` and
      ``bottom``.  Row/header elements are then widened to a common left and
      right edge, and "table column" elements are stretched to a common top
      and bottom edge.

      Groups that contain no detection at all, or no row/header detection,
      are skipped, so the result may be shorter than the base output.

      Returns:
          A list with one list of element dicts per retained group.
      """
      def is_row_like(label):
          # ``> 0`` (not ``>= 0``): the keyword must not be at position 0.
          return label.find("row") > 0 or label.find("header") > 0

      raw_tables = super().__call__(images, thr)
      aligned = []
      for raw in raw_tables:
          elements = []
          for det in raw:
              box = det["bbox"]
              elements.append({
                  "label": det["type"],
                  "score": det["score"],
                  "x0": box[0], "x1": box[2],
                  "top": box[1], "bottom": box[-1],
              })

          row_like = [el for el in elements if is_row_like(el["label"])]
          if not row_like:
              # Covers both "nothing detected" and "no row/header detected".
              continue

          # Horizontal alignment of rows and headers.
          lefts = [el["x0"] for el in row_like]
          rights = [el["x1"] for el in row_like]
          use_mean = len(row_like) > 4
          left_edge = np.mean(lefts) if use_mean else np.min(lefts)
          right_edge = np.mean(rights) if use_mean else np.max(rights)
          for el in row_like:
              el["x0"] = min(el["x0"], left_edge)
              el["x1"] = max(el["x1"], right_edge)

          # Vertical alignment of columns (only if any column was detected).
          columns = [el for el in elements if el["label"] == "table column"]
          if columns:
              tops = [el["top"] for el in columns]
              bottoms = [el["bottom"] for el in columns]
              use_median = len(columns) > 4
              top_edge = np.median(tops) if use_median else np.min(tops)
              bottom_edge = np.median(bottoms) if use_median else np.max(bottoms)
              for el in columns:
                  el["top"] = min(el["top"], top_edge)
                  el["bottom"] = max(el["bottom"], bottom_edge)

          aligned.append(elements)
      return aligned
  79. @staticmethod
  80. def is_caption(bx):
  81. patt = [
  82. r"[图表]+[ 0-9::]{2,}"
  83. ]
  84. if any([re.match(p, bx["text"].strip()) for p in patt]) \
  85. or bx["layout_type"].find("caption") >= 0:
  86. return True
  87. return False
  88. @staticmethod
  89. def blockType(b):
  90. patt = [
  91. ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
  92. (r"^(20|19)[0-9]{2}年$", "Dt"),
  93. (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
  94. ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
  95. (r"^第*[一二三四1-4]季度$", "Dt"),
  96. (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
  97. (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
  98. ("^[0-9.,+%/ -]+$", "Nu"),
  99. (r"^[0-9A-Z/\._~-]+$", "Ca"),
  100. (r"^[A-Z]*[a-z' -]+$", "En"),
  101. (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
  102. (r"^.{1}$", "Sg")
  103. ]
  104. for p, n in patt:
  105. if re.search(p, b["text"].strip()):
  106. return n
  107. tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
  108. if len(tks) > 3:
  109. if len(tks) < 12:
  110. return "Tx"
  111. else:
  112. return "Lx"
  113. if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
  114. return "Nr"
  115. return "Ot"
  @staticmethod
  def construct_table(boxes, is_english=False, html=False):
      """Arrange the text boxes of one table into a grid and render it.

      Steps:
        1. Caption boxes (see ``is_caption``) are removed from ``boxes`` and
           their text is collected into the caption string.
        2. Every remaining box gets a ``btype`` tag from ``blockType``.
        3. Boxes are grouped into rows (``rn``) and columns (``cn``).
        4. With at least 4 rows / 4 columns, a column / row holding a single
           value may be folded into its nearest neighbour.
        5. Header rows are guessed, spans are resolved by ``__cal_spans`` and
           the grid is rendered.

      Args:
          boxes: list of box dicts.  Keys read here: ``text``,
              ``layout_type``, ``top``, ``bottom``, ``x0``, ``x1``,
              ``page_number`` and, when present, ``R``, ``R_top``,
              ``R_bott``, ``C``, ``C_left``, ``C_right`` and ``H``.
              NOTE: the list is modified in place (captions are popped) and
              the dicts gain ``btype``, ``rn`` and ``cn``.
          is_english: selects the separators/wording of the textual output and
              whether caption fragments are joined with a space.
          html: when true an HTML string is returned, otherwise a list of
              descriptive strings.

      Returns:
          ``[]`` when no box is left after removing captions; otherwise the
          result of ``__html_table`` (``html=True``) or ``__desc_table``.
      """
      cap = ""
      i = 0
      while i < len(boxes):
          if TableStructureRecognizer.is_caption(boxes[i]):
              # Separate caption fragments with a space for English text.
              # (The original ``cap + " "`` was a no-op expression.)
              if is_english and cap:
                  cap += " "
              cap += boxes[i]["text"]
              boxes.pop(i)
              i -= 1
          i += 1

      if not boxes:
          return []
      for b in boxes:
          b["btype"] = TableStructureRecognizer.blockType(b)
      # Most frequent block type among all cells.
      max_type = Counter([b["btype"] for b in boxes]).items()
      max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
      logging.debug("MAXTYPE: " + max_type)

      # ---- assign row numbers -------------------------------------------
      rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
      rowh = np.min(rowh) if rowh else 0
      boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
      boxes[0]["rn"] = 0
      rows = [[boxes[0]]]
      btm = boxes[0]["bottom"]
      for b in boxes[1:]:
          b["rn"] = len(rows) - 1
          lst_r = rows[-1]
          # The differing defaults ("-1" vs "-2") make two boxes that both
          # lack "R" compare as different in the second clause.
          if lst_r[-1].get("R", "") != b.get("R", "") \
                  or (b["top"] >= btm - 3
                      and lst_r[-1].get("R", "-1") != b.get("R", "-2")):  # new row
              btm = b["bottom"]
              b["rn"] += 1
              rows.append([b])
              continue
          btm = (btm + b["bottom"]) / 2.
          rows[-1].append(b)

      # ---- assign column numbers ----------------------------------------
      colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
      colwm = np.min(colwm) if colwm else 0
      crosspage = len(set([b["page_number"] for b in boxes])) > 1
      if crosspage:
          boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
      else:
          boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
      boxes[0]["cn"] = 0
      cols = [[boxes[0]]]
      right = boxes[0]["x1"]
      for b in boxes[1:]:
          b["cn"] = len(cols) - 1
          lst_c = cols[-1]
          if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1
                  and b["page_number"] == lst_c[-1]["page_number"]) \
                  or (b["x0"] >= right
                      and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
              right = b["x1"]
              b["cn"] += 1
              cols.append([b])
              continue
          right = (right + b["x1"]) / 2.
          cols[-1].append(b)

      # tbl[row][col] -> list of boxes located in that cell
      tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
      for b in boxes:
          tbl[b["rn"]][b["cn"]].append(b)

      if len(rows) >= 4:
          # remove single in column
          j = 0
          while j < len(tbl[0]):
              # e: number of non-empty cells in column j (counting stops at
              # 2); ii: row index of the last non-empty cell seen.
              e, ii = 0, 0
              for i in range(len(tbl)):
                  if tbl[i][j]:
                      e += 1
                      ii = i
                  if e > 1:
                      break
              if e > 1:
                  j += 1
                  continue
              # f / ff: the left / right neighbour in row ii holds text, or
              # there is no such neighbour (table border).
              f = (j > 0 and tbl[ii][j - 1]
                   and tbl[ii][j - 1][0].get("text")) or j == 0
              ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1]
                    and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
              if f and ff:
                  j += 1
                  continue
              bx = tbl[ii][j][0]
              logging.debug("Relocate column single: " + bx["text"])
              # j column only has one value
              # Smallest horizontal gap to the boxes of the left / right
              # neighbouring column; 100000 acts as "not available".
              left, right = 100000, 100000
              if j > 0 and not f:
                  for i in range(len(tbl)):
                      if tbl[i][j - 1]:
                          left = min(left, np.min(
                              [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
              if j + 1 < len(tbl[0]) and not ff:
                  for i in range(len(tbl)):
                      if tbl[i][j + 1]:
                          right = min(right, np.min(
                              [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
              assert left < 100000 or right < 100000
              if left < right:
                  # Merge into the left column; columns j.. shift by one.
                  for jj in range(j, len(tbl[0])):
                      for i in range(len(tbl)):
                          for a in tbl[i][jj]:
                              a["cn"] -= 1
                  if tbl[ii][j - 1]:
                      tbl[ii][j - 1].extend(tbl[ii][j])
                  else:
                      tbl[ii][j - 1] = tbl[ii][j]
                  for i in range(len(tbl)):
                      tbl[i].pop(j)
              else:
                  # Merge into the right column; columns j+1.. shift by one.
                  for jj in range(j + 1, len(tbl[0])):
                      for i in range(len(tbl)):
                          for a in tbl[i][jj]:
                              a["cn"] -= 1
                  if tbl[ii][j + 1]:
                      tbl[ii][j + 1].extend(tbl[ii][j])
                  else:
                      tbl[ii][j + 1] = tbl[ii][j]
                  for i in range(len(tbl)):
                      tbl[i].pop(j)
              # j is not advanced: the next column has moved into slot j.
              cols.pop(j)
      assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
          len(cols), len(tbl[0]))

      if len(cols) >= 4:
          # remove single in row
          i = 0
          while i < len(tbl):
              # e: number of non-empty cells in row i (counting stops at 2);
              # jj: column index of the last non-empty cell seen.
              e, jj = 0, 0
              for j in range(len(tbl[i])):
                  if tbl[i][j]:
                      e += 1
                      jj = j
                  if e > 1:
                      break
              if e > 1:
                  i += 1
                  continue
              # f / ff: the cell above / below in column jj holds text, or
              # there is no such neighbour (table border).
              f = (i > 0 and tbl[i - 1][jj]
                   and tbl[i - 1][jj][0].get("text")) or i == 0
              ff = (i + 1 < len(tbl) and tbl[i + 1][jj]
                    and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
              if f and ff:
                  i += 1
                  continue

              bx = tbl[i][jj][0]
              logging.debug("Relocate row single: " + bx["text"])
              # i row only has one value
              # Smallest vertical gap to the boxes of the row above / below.
              up, down = 100000, 100000
              if i > 0 and not f:
                  for j in range(len(tbl[i - 1])):
                      if tbl[i - 1][j]:
                          up = min(up, np.min(
                              [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
              if i + 1 < len(tbl) and not ff:
                  for j in range(len(tbl[i + 1])):
                      if tbl[i + 1][j]:
                          down = min(down, np.min(
                              [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
              assert up < 100000 or down < 100000
              if up < down:
                  # Merge into the row above; rows i.. shift by one.
                  for ii in range(i, len(tbl)):
                      for j in range(len(tbl[ii])):
                          for a in tbl[ii][j]:
                              a["rn"] -= 1
                  if tbl[i - 1][jj]:
                      tbl[i - 1][jj].extend(tbl[i][jj])
                  else:
                      tbl[i - 1][jj] = tbl[i][jj]
                  tbl.pop(i)
              else:
                  # Merge into the row below; rows i+1.. shift by one.
                  for ii in range(i + 1, len(tbl)):
                      for j in range(len(tbl[ii])):
                          for a in tbl[ii][j]:
                              a["rn"] -= 1
                  if tbl[i + 1][jj]:
                      tbl[i + 1][jj].extend(tbl[i][jj])
                  else:
                      tbl[i + 1][jj] = tbl[i][jj]
                  tbl.pop(i)
              # i is not advanced: the next row has moved into slot i.
              rows.pop(i)

      # which rows are headers
      hdset = set()
      for i in range(len(tbl)):
          # cnt: non-empty cells in the row; h: cells that look like headers.
          cnt, h = 0, 0
          for arr in tbl[i]:
              if not arr:
                  continue
              cnt += 1
              if max_type == "Nu" and arr[0]["btype"] == "Nu":
                  continue
              if any(a.get("H") for a in arr) \
                      or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                  h += 1
          # ``cnt`` guard: a row without any cell must not divide by zero.
          if cnt and h / cnt > 0.5:
              hdset.add(i)

      if html:
          return TableStructureRecognizer.__html_table(
              cap, hdset,
              TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))

      return TableStructureRecognizer.__desc_table(
          cap, hdset,
          TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False),
          is_english)
  @staticmethod
  def __html_table(cap, hdset, tbl):
      """Render the cell grid as an HTML ``<table>`` string.

      Args:
          cap: caption text; emitted as ``<caption>`` when non-empty.
          hdset: set of header row indices.  It is also used (and mutated) to
              remember the cell texts of header rows already emitted, so a
              header row whose texts were all seen before is skipped.
          tbl: grid ``tbl[row][col]``; a cell is a list of box dicts, an
              empty list (empty cell) or ``None`` (cell omitted entirely).

      Returns:
          The HTML markup as one string; rows are separated by newlines.
      """
      # construct HTML
      out = "<table>"
      if cap:
          out += f"<caption>{cap}</caption>"
      for i, cells in enumerate(tbl):
          is_header = i in hdset
          tag = "th" if is_header else "td"
          pieces = []
          texts = []
          for cell in cells:
              if cell is None:
                  continue
              if not cell:
                  pieces.append(f"<{tag}></{tag}>")
                  continue
              # Half of the smallest box height, capped at 10, is passed to
              # the sort as its second argument.
              tol = min(np.min([c["bottom"] - c["top"] for c in cell]) / 2, 10)
              text = " ".join([c["text"]
                               for c in Recognizer.sort_Y_firstly(cell, tol)])
              texts.append(text)
              first = cell[0]
              attrs = ""
              if first.get("colspan"):
                  attrs = "colspan={}".format(first["colspan"])
              if first.get("rowspan"):
                  attrs += " rowspan={}".format(first["rowspan"])
              pieces.append(f"<{tag} {attrs} >" + text + f"</{tag}>")

          if is_header:
              # Drop a header row that only repeats texts already emitted.
              if all(t in hdset for t in texts):
                  continue
              hdset.update(texts)

          row = "<tr>" + "".join(pieces) + "</tr>" if pieces else ""
          out += "\n" + row
      out += "\n</table>"
      return out
  364. @staticmethod
  365. def __desc_table(cap, hdr_rowno, tbl, is_english):
  366. # get text of every colomn in header row to become header text
  367. clmno = len(tbl[0])
  368. rowno = len(tbl)
  369. headers = {}
  370. hdrset = set()
  371. lst_hdr = []
  372. de = "的" if not is_english else " for "
  373. for r in sorted(list(hdr_rowno)):
  374. headers[r] = ["" for _ in range(clmno)]
  375. for i in range(clmno):
  376. if not tbl[r][i]:
  377. continue
  378. txt = " ".join([a["text"].strip() for a in tbl[r][i]])
  379. headers[r][i] = txt
  380. hdrset.add(txt)
  381. if all([not t for t in headers[r]]):
  382. del headers[r]
  383. hdr_rowno.remove(r)
  384. continue
  385. for j in range(clmno):
  386. if headers[r][j]:
  387. continue
  388. if j >= len(lst_hdr):
  389. break
  390. headers[r][j] = lst_hdr[j]
  391. lst_hdr = headers[r]
  392. for i in range(rowno):
  393. if i not in hdr_rowno:
  394. continue
  395. for j in range(i + 1, rowno):
  396. if j not in hdr_rowno:
  397. break
  398. for k in range(clmno):
  399. if not headers[j - 1][k]:
  400. continue
  401. if headers[j][k].find(headers[j - 1][k]) >= 0:
  402. continue
  403. if len(headers[j][k]) > len(headers[j - 1][k]):
  404. headers[j][k] += (de if headers[j][k]
  405. else "") + headers[j - 1][k]
  406. else:
  407. headers[j][k] = headers[j - 1][k] \
  408. + (de if headers[j - 1][k] else "") \
  409. + headers[j][k]
  410. logging.debug(
  411. f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
  412. row_txt = []
  413. for i in range(rowno):
  414. if i in hdr_rowno:
  415. continue
  416. rtxt = []
  417. def append(delimer):
  418. nonlocal rtxt, row_txt
  419. rtxt = delimer.join(rtxt)
  420. if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
  421. row_txt[-1] += "\n" + rtxt
  422. else:
  423. row_txt.append(rtxt)
  424. r = 0
  425. if len(headers.items()):
  426. _arr = [(i - r, r) for r, _ in headers.items() if r < i]
  427. if _arr:
  428. _, r = min(_arr, key=lambda x: x[0])
  429. if r not in headers and clmno <= 2:
  430. for j in range(clmno):
  431. if not tbl[i][j]:
  432. continue
  433. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  434. if txt:
  435. rtxt.append(txt)
  436. if rtxt:
  437. append(":")
  438. continue
  439. for j in range(clmno):
  440. if not tbl[i][j]:
  441. continue
  442. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  443. if not txt:
  444. continue
  445. ctt = headers[r][j] if r in headers else ""
  446. if ctt:
  447. ctt += ":"
  448. ctt += txt
  449. if ctt:
  450. rtxt.append(ctt)
  451. if rtxt:
  452. row_txt.append("; ".join(rtxt))
  453. if cap:
  454. if is_english:
  455. from_ = " in "
  456. else:
  457. from_ = "来自"
  458. row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
  459. return row_txt
  460. @staticmethod
  461. def __cal_spans(boxes, rows, cols, tbl, html=True):
  462. # caculate span
  463. clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
  464. for cln in cols]
  465. crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
  466. for cln in cols]
  467. rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
  468. for row in rows]
  469. rbtm = [np.mean([c.get("R_btm", c["bottom"])
  470. for c in row]) for row in rows]
  471. for b in boxes:
  472. if "SP" not in b:
  473. continue
  474. b["colspan"] = [b["cn"]]
  475. b["rowspan"] = [b["rn"]]
  476. # col span
  477. for j in range(0, len(clft)):
  478. if j == b["cn"]:
  479. continue
  480. if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
  481. continue
  482. if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
  483. continue
  484. b["colspan"].append(j)
  485. # row span
  486. for j in range(0, len(rtop)):
  487. if j == b["rn"]:
  488. continue
  489. if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
  490. continue
  491. if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
  492. continue
  493. b["rowspan"].append(j)
  494. def join(arr):
  495. if not arr:
  496. return ""
  497. return "".join([t["text"] for t in arr])
  498. # rm the spaning cells
  499. for i in range(len(tbl)):
  500. for j, arr in enumerate(tbl[i]):
  501. if not arr:
  502. continue
  503. if all(["rowspan" not in a and "colspan" not in a for a in arr]):
  504. continue
  505. rowspan, colspan = [], []
  506. for a in arr:
  507. if isinstance(a.get("rowspan", 0), list):
  508. rowspan.extend(a["rowspan"])
  509. if isinstance(a.get("colspan", 0), list):
  510. colspan.extend(a["colspan"])
  511. rowspan, colspan = set(rowspan), set(colspan)
  512. if len(rowspan) < 2 and len(colspan) < 2:
  513. for a in arr:
  514. if "rowspan" in a:
  515. del a["rowspan"]
  516. if "colspan" in a:
  517. del a["colspan"]
  518. continue
  519. rowspan, colspan = sorted(rowspan), sorted(colspan)
  520. rowspan = list(range(rowspan[0], rowspan[-1] + 1))
  521. colspan = list(range(colspan[0], colspan[-1] + 1))
  522. assert i in rowspan, rowspan
  523. assert j in colspan, colspan
  524. arr = []
  525. for r in rowspan:
  526. for c in colspan:
  527. arr_txt = join(arr)
  528. if tbl[r][c] and join(tbl[r][c]) != arr_txt:
  529. arr.extend(tbl[r][c])
  530. tbl[r][c] = None if html else arr
  531. for a in arr:
  532. if len(rowspan) > 1:
  533. a["rowspan"] = len(rowspan)
  534. elif "rowspan" in a:
  535. del a["rowspan"]
  536. if len(colspan) > 1:
  537. a["colspan"] = len(colspan)
  538. elif "colspan" in a:
  539. del a["colspan"]
  540. tbl[rowspan[0]][colspan[0]] = arr
  541. return tbl