You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

table_structure_recognizer.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. import logging
  2. import os
  3. import re
  4. from collections import Counter
  5. from copy import deepcopy
  6. import numpy as np
  7. from api.utils.file_utils import get_project_base_directory
  8. from rag.nlp import huqie
  9. from .recognizer import Recognizer
  10. class TableStructureRecognizer(Recognizer):
  11. def __init__(self):
  12. self.labels = [
  13. "table",
  14. "table column",
  15. "table row",
  16. "table column header",
  17. "table projected row header",
  18. "table spanning cell",
  19. ]
  20. super().__init__(self.labels, "tsr",
  21. os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
  22. def __call__(self, images, thr=0.5):
  23. tbls = super().__call__(images, thr)
  24. res = []
  25. # align left&right for rows, align top&bottom for columns
  26. for tbl in tbls:
  27. lts = [{"label": b["type"],
  28. "score": b["score"],
  29. "x0": b["bbox"][0], "x1": b["bbox"][2],
  30. "top": b["bbox"][1], "bottom": b["bbox"][-1]
  31. } for b in tbl]
  32. if not lts:
  33. continue
  34. left = [b["x0"] for b in lts if b["label"].find(
  35. "row") > 0 or b["label"].find("header") > 0]
  36. right = [b["x1"] for b in lts if b["label"].find(
  37. "row") > 0 or b["label"].find("header") > 0]
  38. if not left:
  39. continue
  40. left = np.median(left) if len(left) > 4 else np.min(left)
  41. right = np.median(right) if len(right) > 4 else np.max(right)
  42. for b in lts:
  43. if b["label"].find("row") > 0 or b["label"].find("header") > 0:
  44. if b["x0"] > left:
  45. b["x0"] = left
  46. if b["x1"] < right:
  47. b["x1"] = right
  48. top = [b["top"] for b in lts if b["label"] == "table column"]
  49. bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
  50. if not top:
  51. res.append(lts)
  52. continue
  53. top = np.median(top) if len(top) > 4 else np.min(top)
  54. bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
  55. for b in lts:
  56. if b["label"] == "table column":
  57. if b["top"] > top:
  58. b["top"] = top
  59. if b["bottom"] < bottom:
  60. b["bottom"] = bottom
  61. res.append(lts)
  62. return res
  63. @staticmethod
  64. def is_caption(bx):
  65. patt = [
  66. r"[图表]+[ 0-9::]{2,}"
  67. ]
  68. if any([re.match(p, bx["text"].strip()) for p in patt]) \
  69. or bx["layout_type"].find("caption") >= 0:
  70. return True
  71. return False
  72. def __blockType(self, b):
  73. patt = [
  74. ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
  75. (r"^(20|19)[0-9]{2}年$", "Dt"),
  76. (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
  77. ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
  78. (r"^第*[一二三四1-4]季度$", "Dt"),
  79. (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
  80. (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
  81. ("^[0-9.,+%/ -]+$", "Nu"),
  82. (r"^[0-9A-Z/\._~-]+$", "Ca"),
  83. (r"^[A-Z]*[a-z' -]+$", "En"),
  84. (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
  85. (r"^.{1}$", "Sg")
  86. ]
  87. for p, n in patt:
  88. if re.search(p, b["text"].strip()):
  89. return n
  90. tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
  91. if len(tks) > 3:
  92. if len(tks) < 12:
  93. return "Tx"
  94. else:
  95. return "Lx"
  96. if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
  97. return "Nr"
  98. return "Ot"
  99. def construct_table(self, boxes, is_english=False, html=False):
  100. cap = ""
  101. i = 0
  102. while i < len(boxes):
  103. if self.is_caption(boxes[i]):
  104. cap += boxes[i]["text"]
  105. boxes.pop(i)
  106. i -= 1
  107. i += 1
  108. if not boxes:
  109. return []
  110. for b in boxes:
  111. b["btype"] = self.__blockType(b)
  112. max_type = Counter([b["btype"] for b in boxes]).items()
  113. max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
  114. logging.debug("MAXTYPE: " + max_type)
  115. rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
  116. rowh = np.min(rowh) if rowh else 0
  117. boxes = self.sort_R_firstly(boxes, rowh / 2)
  118. boxes[0]["rn"] = 0
  119. rows = [[boxes[0]]]
  120. btm = boxes[0]["bottom"]
  121. for b in boxes[1:]:
  122. b["rn"] = len(rows) - 1
  123. lst_r = rows[-1]
  124. if lst_r[-1].get("R", "") != b.get("R", "") \
  125. or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
  126. ): # new row
  127. btm = b["bottom"]
  128. b["rn"] += 1
  129. rows.append([b])
  130. continue
  131. btm = (btm + b["bottom"]) / 2.
  132. rows[-1].append(b)
  133. colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
  134. colwm = np.min(colwm) if colwm else 0
  135. crosspage = len(set([b["page_number"] for b in boxes])) > 1
  136. if crosspage:
  137. boxes = self.sort_X_firstly(boxes, colwm / 2, False)
  138. else:
  139. boxes = self.sort_C_firstly(boxes, colwm / 2)
  140. boxes[0]["cn"] = 0
  141. cols = [[boxes[0]]]
  142. right = boxes[0]["x1"]
  143. for b in boxes[1:]:
  144. b["cn"] = len(cols) - 1
  145. lst_c = cols[-1]
  146. if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
  147. "page_number"]) \
  148. or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
  149. right = b["x1"]
  150. b["cn"] += 1
  151. cols.append([b])
  152. continue
  153. right = (right + b["x1"]) / 2.
  154. cols[-1].append(b)
  155. tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
  156. for b in boxes:
  157. tbl[b["rn"]][b["cn"]].append(b)
  158. if len(rows) >= 4:
  159. # remove single in column
  160. j = 0
  161. while j < len(tbl[0]):
  162. e, ii = 0, 0
  163. for i in range(len(tbl)):
  164. if tbl[i][j]:
  165. e += 1
  166. ii = i
  167. if e > 1:
  168. break
  169. if e > 1:
  170. j += 1
  171. continue
  172. f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
  173. [j - 1][0].get("text")) or j == 0
  174. ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
  175. [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
  176. if f and ff:
  177. j += 1
  178. continue
  179. bx = tbl[ii][j][0]
  180. logging.debug("Relocate column single: " + bx["text"])
  181. # j column only has one value
  182. left, right = 100000, 100000
  183. if j > 0 and not f:
  184. for i in range(len(tbl)):
  185. if tbl[i][j - 1]:
  186. left = min(left, np.min(
  187. [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
  188. if j + 1 < len(tbl[0]) and not ff:
  189. for i in range(len(tbl)):
  190. if tbl[i][j + 1]:
  191. right = min(right, np.min(
  192. [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
  193. assert left < 100000 or right < 100000
  194. if left < right:
  195. for jj in range(j, len(tbl[0])):
  196. for i in range(len(tbl)):
  197. for a in tbl[i][jj]:
  198. a["cn"] -= 1
  199. if tbl[ii][j - 1]:
  200. tbl[ii][j - 1].extend(tbl[ii][j])
  201. else:
  202. tbl[ii][j - 1] = tbl[ii][j]
  203. for i in range(len(tbl)):
  204. tbl[i].pop(j)
  205. else:
  206. for jj in range(j + 1, len(tbl[0])):
  207. for i in range(len(tbl)):
  208. for a in tbl[i][jj]:
  209. a["cn"] -= 1
  210. if tbl[ii][j + 1]:
  211. tbl[ii][j + 1].extend(tbl[ii][j])
  212. else:
  213. tbl[ii][j + 1] = tbl[ii][j]
  214. for i in range(len(tbl)):
  215. tbl[i].pop(j)
  216. cols.pop(j)
  217. assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
  218. len(cols), len(tbl[0]))
  219. if len(cols) >= 4:
  220. # remove single in row
  221. i = 0
  222. while i < len(tbl):
  223. e, jj = 0, 0
  224. for j in range(len(tbl[i])):
  225. if tbl[i][j]:
  226. e += 1
  227. jj = j
  228. if e > 1:
  229. break
  230. if e > 1:
  231. i += 1
  232. continue
  233. f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
  234. [jj][0].get("text")) or i == 0
  235. ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
  236. [jj][0].get("text")) or i + 1 >= len(tbl)
  237. if f and ff:
  238. i += 1
  239. continue
  240. bx = tbl[i][jj][0]
  241. logging.debug("Relocate row single: " + bx["text"])
  242. # i row only has one value
  243. up, down = 100000, 100000
  244. if i > 0 and not f:
  245. for j in range(len(tbl[i - 1])):
  246. if tbl[i - 1][j]:
  247. up = min(up, np.min(
  248. [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
  249. if i + 1 < len(tbl) and not ff:
  250. for j in range(len(tbl[i + 1])):
  251. if tbl[i + 1][j]:
  252. down = min(down, np.min(
  253. [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
  254. assert up < 100000 or down < 100000
  255. if up < down:
  256. for ii in range(i, len(tbl)):
  257. for j in range(len(tbl[ii])):
  258. for a in tbl[ii][j]:
  259. a["rn"] -= 1
  260. if tbl[i - 1][jj]:
  261. tbl[i - 1][jj].extend(tbl[i][jj])
  262. else:
  263. tbl[i - 1][jj] = tbl[i][jj]
  264. tbl.pop(i)
  265. else:
  266. for ii in range(i + 1, len(tbl)):
  267. for j in range(len(tbl[ii])):
  268. for a in tbl[ii][j]:
  269. a["rn"] -= 1
  270. if tbl[i + 1][jj]:
  271. tbl[i + 1][jj].extend(tbl[i][jj])
  272. else:
  273. tbl[i + 1][jj] = tbl[i][jj]
  274. tbl.pop(i)
  275. rows.pop(i)
  276. # which rows are headers
  277. hdset = set([])
  278. for i in range(len(tbl)):
  279. cnt, h = 0, 0
  280. for j, arr in enumerate(tbl[i]):
  281. if not arr:
  282. continue
  283. cnt += 1
  284. if max_type == "Nu" and arr[0]["btype"] == "Nu":
  285. continue
  286. if any([a.get("H") for a in arr]) \
  287. or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
  288. h += 1
  289. if h / cnt > 0.5:
  290. hdset.add(i)
  291. if html:
  292. return [self.__html_table(cap, hdset,
  293. self.__cal_spans(boxes, rows,
  294. cols, tbl, True)
  295. )]
  296. return self.__desc_table(cap, hdset,
  297. self.__cal_spans(boxes, rows, cols, tbl, False),
  298. is_english)
  299. def __html_table(self, cap, hdset, tbl):
  300. # constrcut HTML
  301. html = "<table>"
  302. if cap:
  303. html += f"<caption>{cap}</caption>"
  304. for i in range(len(tbl)):
  305. row = "<tr>"
  306. txts = []
  307. for j, arr in enumerate(tbl[i]):
  308. if arr is None:
  309. continue
  310. if not arr:
  311. row += "<td></td>" if i not in hdset else "<th></th>"
  312. continue
  313. txt = ""
  314. if arr:
  315. h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
  316. txt = "".join([c["text"]
  317. for c in self.sort_Y_firstly(arr, h)])
  318. txts.append(txt)
  319. sp = ""
  320. if arr[0].get("colspan"):
  321. sp = "colspan={}".format(arr[0]["colspan"])
  322. if arr[0].get("rowspan"):
  323. sp += " rowspan={}".format(arr[0]["rowspan"])
  324. if i in hdset:
  325. row += f"<th {sp} >" + txt + "</th>"
  326. else:
  327. row += f"<td {sp} >" + txt + "</td>"
  328. if i in hdset:
  329. if all([t in hdset for t in txts]):
  330. continue
  331. for t in txts:
  332. hdset.add(t)
  333. if row != "<tr>":
  334. row += "</tr>"
  335. else:
  336. row = ""
  337. html += "\n" + row
  338. html += "\n</table>"
  339. return html
  340. def __desc_table(self, cap, hdr_rowno, tbl, is_english):
  341. # get text of every colomn in header row to become header text
  342. clmno = len(tbl[0])
  343. rowno = len(tbl)
  344. headers = {}
  345. hdrset = set()
  346. lst_hdr = []
  347. de = "的" if not is_english else " for "
  348. for r in sorted(list(hdr_rowno)):
  349. headers[r] = ["" for _ in range(clmno)]
  350. for i in range(clmno):
  351. if not tbl[r][i]:
  352. continue
  353. txt = "".join([a["text"].strip() for a in tbl[r][i]])
  354. headers[r][i] = txt
  355. hdrset.add(txt)
  356. if all([not t for t in headers[r]]):
  357. del headers[r]
  358. hdr_rowno.remove(r)
  359. continue
  360. for j in range(clmno):
  361. if headers[r][j]:
  362. continue
  363. if j >= len(lst_hdr):
  364. break
  365. headers[r][j] = lst_hdr[j]
  366. lst_hdr = headers[r]
  367. for i in range(rowno):
  368. if i not in hdr_rowno:
  369. continue
  370. for j in range(i + 1, rowno):
  371. if j not in hdr_rowno:
  372. break
  373. for k in range(clmno):
  374. if not headers[j - 1][k]:
  375. continue
  376. if headers[j][k].find(headers[j - 1][k]) >= 0:
  377. continue
  378. if len(headers[j][k]) > len(headers[j - 1][k]):
  379. headers[j][k] += (de if headers[j][k]
  380. else "") + headers[j - 1][k]
  381. else:
  382. headers[j][k] = headers[j - 1][k] \
  383. + (de if headers[j - 1][k] else "") \
  384. + headers[j][k]
  385. logging.debug(
  386. f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
  387. row_txt = []
  388. for i in range(rowno):
  389. if i in hdr_rowno:
  390. continue
  391. rtxt = []
  392. def append(delimer):
  393. nonlocal rtxt, row_txt
  394. rtxt = delimer.join(rtxt)
  395. if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
  396. row_txt[-1] += "\n" + rtxt
  397. else:
  398. row_txt.append(rtxt)
  399. r = 0
  400. if len(headers.items()):
  401. _arr = [(i - r, r) for r, _ in headers.items() if r < i]
  402. if _arr:
  403. _, r = min(_arr, key=lambda x: x[0])
  404. if r not in headers and clmno <= 2:
  405. for j in range(clmno):
  406. if not tbl[i][j]:
  407. continue
  408. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  409. if txt:
  410. rtxt.append(txt)
  411. if rtxt:
  412. append(":")
  413. continue
  414. for j in range(clmno):
  415. if not tbl[i][j]:
  416. continue
  417. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  418. if not txt:
  419. continue
  420. ctt = headers[r][j] if r in headers else ""
  421. if ctt:
  422. ctt += ":"
  423. ctt += txt
  424. if ctt:
  425. rtxt.append(ctt)
  426. if rtxt:
  427. row_txt.append("; ".join(rtxt))
  428. if cap:
  429. if is_english:
  430. from_ = " in "
  431. else:
  432. from_ = "来自"
  433. row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
  434. return row_txt
  435. def __cal_spans(self, boxes, rows, cols, tbl, html=True):
  436. # caculate span
  437. clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
  438. for cln in cols]
  439. crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
  440. for cln in cols]
  441. rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
  442. for row in rows]
  443. rbtm = [np.mean([c.get("R_btm", c["bottom"])
  444. for c in row]) for row in rows]
  445. for b in boxes:
  446. if "SP" not in b:
  447. continue
  448. b["colspan"] = [b["cn"]]
  449. b["rowspan"] = [b["rn"]]
  450. # col span
  451. for j in range(0, len(clft)):
  452. if j == b["cn"]:
  453. continue
  454. if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
  455. continue
  456. if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
  457. continue
  458. b["colspan"].append(j)
  459. # row span
  460. for j in range(0, len(rtop)):
  461. if j == b["rn"]:
  462. continue
  463. if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
  464. continue
  465. if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
  466. continue
  467. b["rowspan"].append(j)
  468. def join(arr):
  469. if not arr:
  470. return ""
  471. return "".join([t["text"] for t in arr])
  472. # rm the spaning cells
  473. for i in range(len(tbl)):
  474. for j, arr in enumerate(tbl[i]):
  475. if not arr:
  476. continue
  477. if all(["rowspan" not in a and "colspan" not in a for a in arr]):
  478. continue
  479. rowspan, colspan = [], []
  480. for a in arr:
  481. if isinstance(a.get("rowspan", 0), list):
  482. rowspan.extend(a["rowspan"])
  483. if isinstance(a.get("colspan", 0), list):
  484. colspan.extend(a["colspan"])
  485. rowspan, colspan = set(rowspan), set(colspan)
  486. if len(rowspan) < 2 and len(colspan) < 2:
  487. for a in arr:
  488. if "rowspan" in a:
  489. del a["rowspan"]
  490. if "colspan" in a:
  491. del a["colspan"]
  492. continue
  493. rowspan, colspan = sorted(rowspan), sorted(colspan)
  494. rowspan = list(range(rowspan[0], rowspan[-1] + 1))
  495. colspan = list(range(colspan[0], colspan[-1] + 1))
  496. assert i in rowspan, rowspan
  497. assert j in colspan, colspan
  498. arr = []
  499. for r in rowspan:
  500. for c in colspan:
  501. arr_txt = join(arr)
  502. if tbl[r][c] and join(tbl[r][c]) != arr_txt:
  503. arr.extend(tbl[r][c])
  504. tbl[r][c] = None if html else arr
  505. for a in arr:
  506. if len(rowspan) > 1:
  507. a["rowspan"] = len(rowspan)
  508. elif "rowspan" in a:
  509. del a["rowspan"]
  510. if len(colspan) > 1:
  511. a["colspan"] = len(colspan)
  512. elif "colspan" in a:
  513. del a["colspan"]
  514. tbl[rowspan[0]][colspan[0]] = arr
  515. return tbl