Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

table_structure_recognizer.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import re
  19. from collections import Counter
  20. import numpy as np
  21. from huggingface_hub import snapshot_download
  22. from api.utils.file_utils import get_project_base_directory
  23. from rag.nlp import rag_tokenizer
  24. from .recognizer import Recognizer
  25. class TableStructureRecognizer(Recognizer):
  26. labels = [
  27. "table",
  28. "table column",
  29. "table row",
  30. "table column header",
  31. "table projected row header",
  32. "table spanning cell",
  33. ]
  34. def __init__(self):
  35. try:
  36. super().__init__(self.labels, "tsr", os.path.join(
  37. get_project_base_directory(),
  38. "rag/res/deepdoc"))
  39. except Exception:
  40. super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
  41. local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
  42. local_dir_use_symlinks=False))
  43. def __call__(self, images, thr=0.2):
  44. tbls = super().__call__(images, thr)
  45. res = []
  46. # align left&right for rows, align top&bottom for columns
  47. for tbl in tbls:
  48. lts = [{"label": b["type"],
  49. "score": b["score"],
  50. "x0": b["bbox"][0], "x1": b["bbox"][2],
  51. "top": b["bbox"][1], "bottom": b["bbox"][-1]
  52. } for b in tbl]
  53. if not lts:
  54. continue
  55. left = [b["x0"] for b in lts if b["label"].find(
  56. "row") > 0 or b["label"].find("header") > 0]
  57. right = [b["x1"] for b in lts if b["label"].find(
  58. "row") > 0 or b["label"].find("header") > 0]
  59. if not left:
  60. continue
  61. left = np.mean(left) if len(left) > 4 else np.min(left)
  62. right = np.mean(right) if len(right) > 4 else np.max(right)
  63. for b in lts:
  64. if b["label"].find("row") > 0 or b["label"].find("header") > 0:
  65. if b["x0"] > left:
  66. b["x0"] = left
  67. if b["x1"] < right:
  68. b["x1"] = right
  69. top = [b["top"] for b in lts if b["label"] == "table column"]
  70. bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
  71. if not top:
  72. res.append(lts)
  73. continue
  74. top = np.median(top) if len(top) > 4 else np.min(top)
  75. bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
  76. for b in lts:
  77. if b["label"] == "table column":
  78. if b["top"] > top:
  79. b["top"] = top
  80. if b["bottom"] < bottom:
  81. b["bottom"] = bottom
  82. res.append(lts)
  83. return res
  84. @staticmethod
  85. def is_caption(bx):
  86. patt = [
  87. r"[图表]+[ 0-9::]{2,}"
  88. ]
  89. if any([re.match(p, bx["text"].strip()) for p in patt]) \
  90. or bx["layout_type"].find("caption") >= 0:
  91. return True
  92. return False
  93. @staticmethod
  94. def blockType(b):
  95. patt = [
  96. ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
  97. (r"^(20|19)[0-9]{2}年$", "Dt"),
  98. (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
  99. ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
  100. (r"^第*[一二三四1-4]季度$", "Dt"),
  101. (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
  102. (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
  103. ("^[0-9.,+%/ -]+$", "Nu"),
  104. (r"^[0-9A-Z/\._~-]+$", "Ca"),
  105. (r"^[A-Z]*[a-z' -]+$", "En"),
  106. (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
  107. (r"^.{1}$", "Sg")
  108. ]
  109. for p, n in patt:
  110. if re.search(p, b["text"].strip()):
  111. return n
  112. tks = [t for t in rag_tokenizer.tokenize(b["text"]).split() if len(t) > 1]
  113. if len(tks) > 3:
  114. if len(tks) < 12:
  115. return "Tx"
  116. else:
  117. return "Lx"
  118. if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
  119. return "Nr"
  120. return "Ot"
  121. @staticmethod
  122. def construct_table(boxes, is_english=False, html=True, **kwargs):
  123. cap = ""
  124. i = 0
  125. while i < len(boxes):
  126. if TableStructureRecognizer.is_caption(boxes[i]):
  127. if is_english:
  128. cap + " "
  129. cap += boxes[i]["text"]
  130. boxes.pop(i)
  131. i -= 1
  132. i += 1
  133. if not boxes:
  134. return []
  135. for b in boxes:
  136. b["btype"] = TableStructureRecognizer.blockType(b)
  137. max_type = Counter([b["btype"] for b in boxes]).items()
  138. max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
  139. logging.debug("MAXTYPE: " + max_type)
  140. rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
  141. rowh = np.min(rowh) if rowh else 0
  142. boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
  143. #for b in boxes:print(b)
  144. boxes[0]["rn"] = 0
  145. rows = [[boxes[0]]]
  146. btm = boxes[0]["bottom"]
  147. for b in boxes[1:]:
  148. b["rn"] = len(rows) - 1
  149. lst_r = rows[-1]
  150. if lst_r[-1].get("R", "") != b.get("R", "") \
  151. or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
  152. ): # new row
  153. btm = b["bottom"]
  154. b["rn"] += 1
  155. rows.append([b])
  156. continue
  157. btm = (btm + b["bottom"]) / 2.
  158. rows[-1].append(b)
  159. colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
  160. colwm = np.min(colwm) if colwm else 0
  161. crosspage = len(set([b["page_number"] for b in boxes])) > 1
  162. if crosspage:
  163. boxes = Recognizer.sort_X_firstly(boxes, colwm / 2)
  164. else:
  165. boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
  166. boxes[0]["cn"] = 0
  167. cols = [[boxes[0]]]
  168. right = boxes[0]["x1"]
  169. for b in boxes[1:]:
  170. b["cn"] = len(cols) - 1
  171. lst_c = cols[-1]
  172. if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
  173. "page_number"]) \
  174. or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
  175. right = b["x1"]
  176. b["cn"] += 1
  177. cols.append([b])
  178. continue
  179. right = (right + b["x1"]) / 2.
  180. cols[-1].append(b)
  181. tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
  182. for b in boxes:
  183. tbl[b["rn"]][b["cn"]].append(b)
  184. if len(rows) >= 4:
  185. # remove single in column
  186. j = 0
  187. while j < len(tbl[0]):
  188. e, ii = 0, 0
  189. for i in range(len(tbl)):
  190. if tbl[i][j]:
  191. e += 1
  192. ii = i
  193. if e > 1:
  194. break
  195. if e > 1:
  196. j += 1
  197. continue
  198. f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
  199. [j - 1][0].get("text")) or j == 0
  200. ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
  201. [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
  202. if f and ff:
  203. j += 1
  204. continue
  205. bx = tbl[ii][j][0]
  206. logging.debug("Relocate column single: " + bx["text"])
  207. # j column only has one value
  208. left, right = 100000, 100000
  209. if j > 0 and not f:
  210. for i in range(len(tbl)):
  211. if tbl[i][j - 1]:
  212. left = min(left, np.min(
  213. [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
  214. if j + 1 < len(tbl[0]) and not ff:
  215. for i in range(len(tbl)):
  216. if tbl[i][j + 1]:
  217. right = min(right, np.min(
  218. [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
  219. assert left < 100000 or right < 100000
  220. if left < right:
  221. for jj in range(j, len(tbl[0])):
  222. for i in range(len(tbl)):
  223. for a in tbl[i][jj]:
  224. a["cn"] -= 1
  225. if tbl[ii][j - 1]:
  226. tbl[ii][j - 1].extend(tbl[ii][j])
  227. else:
  228. tbl[ii][j - 1] = tbl[ii][j]
  229. for i in range(len(tbl)):
  230. tbl[i].pop(j)
  231. else:
  232. for jj in range(j + 1, len(tbl[0])):
  233. for i in range(len(tbl)):
  234. for a in tbl[i][jj]:
  235. a["cn"] -= 1
  236. if tbl[ii][j + 1]:
  237. tbl[ii][j + 1].extend(tbl[ii][j])
  238. else:
  239. tbl[ii][j + 1] = tbl[ii][j]
  240. for i in range(len(tbl)):
  241. tbl[i].pop(j)
  242. cols.pop(j)
  243. assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
  244. len(cols), len(tbl[0]))
  245. if len(cols) >= 4:
  246. # remove single in row
  247. i = 0
  248. while i < len(tbl):
  249. e, jj = 0, 0
  250. for j in range(len(tbl[i])):
  251. if tbl[i][j]:
  252. e += 1
  253. jj = j
  254. if e > 1:
  255. break
  256. if e > 1:
  257. i += 1
  258. continue
  259. f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
  260. [jj][0].get("text")) or i == 0
  261. ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
  262. [jj][0].get("text")) or i + 1 >= len(tbl)
  263. if f and ff:
  264. i += 1
  265. continue
  266. bx = tbl[i][jj][0]
  267. logging.debug("Relocate row single: " + bx["text"])
  268. # i row only has one value
  269. up, down = 100000, 100000
  270. if i > 0 and not f:
  271. for j in range(len(tbl[i - 1])):
  272. if tbl[i - 1][j]:
  273. up = min(up, np.min(
  274. [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
  275. if i + 1 < len(tbl) and not ff:
  276. for j in range(len(tbl[i + 1])):
  277. if tbl[i + 1][j]:
  278. down = min(down, np.min(
  279. [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
  280. assert up < 100000 or down < 100000
  281. if up < down:
  282. for ii in range(i, len(tbl)):
  283. for j in range(len(tbl[ii])):
  284. for a in tbl[ii][j]:
  285. a["rn"] -= 1
  286. if tbl[i - 1][jj]:
  287. tbl[i - 1][jj].extend(tbl[i][jj])
  288. else:
  289. tbl[i - 1][jj] = tbl[i][jj]
  290. tbl.pop(i)
  291. else:
  292. for ii in range(i + 1, len(tbl)):
  293. for j in range(len(tbl[ii])):
  294. for a in tbl[ii][j]:
  295. a["rn"] -= 1
  296. if tbl[i + 1][jj]:
  297. tbl[i + 1][jj].extend(tbl[i][jj])
  298. else:
  299. tbl[i + 1][jj] = tbl[i][jj]
  300. tbl.pop(i)
  301. rows.pop(i)
  302. # which rows are headers
  303. hdset = set([])
  304. for i in range(len(tbl)):
  305. cnt, h = 0, 0
  306. for j, arr in enumerate(tbl[i]):
  307. if not arr:
  308. continue
  309. cnt += 1
  310. if max_type == "Nu" and arr[0]["btype"] == "Nu":
  311. continue
  312. if any([a.get("H") for a in arr]) \
  313. or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
  314. h += 1
  315. if h / cnt > 0.5:
  316. hdset.add(i)
  317. if html:
  318. return TableStructureRecognizer.__html_table(cap, hdset,
  319. TableStructureRecognizer.__cal_spans(boxes, rows,
  320. cols, tbl, True)
  321. )
  322. return TableStructureRecognizer.__desc_table(cap, hdset,
  323. TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
  324. False),
  325. is_english)
  326. @staticmethod
  327. def __html_table(cap, hdset, tbl):
  328. # constrcut HTML
  329. html = "<table>"
  330. if cap:
  331. html += f"<caption>{cap}</caption>"
  332. for i in range(len(tbl)):
  333. row = "<tr>"
  334. txts = []
  335. for j, arr in enumerate(tbl[i]):
  336. if arr is None:
  337. continue
  338. if not arr:
  339. row += "<td></td>" if i not in hdset else "<th></th>"
  340. continue
  341. txt = ""
  342. if arr:
  343. h = min(np.min([c["bottom"] - c["top"]
  344. for c in arr]) / 2, 10)
  345. txt = " ".join([c["text"]
  346. for c in Recognizer.sort_Y_firstly(arr, h)])
  347. txts.append(txt)
  348. sp = ""
  349. if arr[0].get("colspan"):
  350. sp = "colspan={}".format(arr[0]["colspan"])
  351. if arr[0].get("rowspan"):
  352. sp += " rowspan={}".format(arr[0]["rowspan"])
  353. if i in hdset:
  354. row += f"<th {sp} >" + txt + "</th>"
  355. else:
  356. row += f"<td {sp} >" + txt + "</td>"
  357. if i in hdset:
  358. if all([t in hdset for t in txts]):
  359. continue
  360. for t in txts:
  361. hdset.add(t)
  362. if row != "<tr>":
  363. row += "</tr>"
  364. else:
  365. row = ""
  366. html += "\n" + row
  367. html += "\n</table>"
  368. return html
  369. @staticmethod
  370. def __desc_table(cap, hdr_rowno, tbl, is_english):
  371. # get text of every colomn in header row to become header text
  372. clmno = len(tbl[0])
  373. rowno = len(tbl)
  374. headers = {}
  375. hdrset = set()
  376. lst_hdr = []
  377. de = "的" if not is_english else " for "
  378. for r in sorted(list(hdr_rowno)):
  379. headers[r] = ["" for _ in range(clmno)]
  380. for i in range(clmno):
  381. if not tbl[r][i]:
  382. continue
  383. txt = " ".join([a["text"].strip() for a in tbl[r][i]])
  384. headers[r][i] = txt
  385. hdrset.add(txt)
  386. if all([not t for t in headers[r]]):
  387. del headers[r]
  388. hdr_rowno.remove(r)
  389. continue
  390. for j in range(clmno):
  391. if headers[r][j]:
  392. continue
  393. if j >= len(lst_hdr):
  394. break
  395. headers[r][j] = lst_hdr[j]
  396. lst_hdr = headers[r]
  397. for i in range(rowno):
  398. if i not in hdr_rowno:
  399. continue
  400. for j in range(i + 1, rowno):
  401. if j not in hdr_rowno:
  402. break
  403. for k in range(clmno):
  404. if not headers[j - 1][k]:
  405. continue
  406. if headers[j][k].find(headers[j - 1][k]) >= 0:
  407. continue
  408. if len(headers[j][k]) > len(headers[j - 1][k]):
  409. headers[j][k] += (de if headers[j][k]
  410. else "") + headers[j - 1][k]
  411. else:
  412. headers[j][k] = headers[j - 1][k] \
  413. + (de if headers[j - 1][k] else "") \
  414. + headers[j][k]
  415. logging.debug(
  416. f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
  417. row_txt = []
  418. for i in range(rowno):
  419. if i in hdr_rowno:
  420. continue
  421. rtxt = []
  422. def append(delimer):
  423. nonlocal rtxt, row_txt
  424. rtxt = delimer.join(rtxt)
  425. if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
  426. row_txt[-1] += "\n" + rtxt
  427. else:
  428. row_txt.append(rtxt)
  429. r = 0
  430. if len(headers.items()):
  431. _arr = [(i - r, r) for r, _ in headers.items() if r < i]
  432. if _arr:
  433. _, r = min(_arr, key=lambda x: x[0])
  434. if r not in headers and clmno <= 2:
  435. for j in range(clmno):
  436. if not tbl[i][j]:
  437. continue
  438. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  439. if txt:
  440. rtxt.append(txt)
  441. if rtxt:
  442. append(":")
  443. continue
  444. for j in range(clmno):
  445. if not tbl[i][j]:
  446. continue
  447. txt = "".join([a["text"].strip() for a in tbl[i][j]])
  448. if not txt:
  449. continue
  450. ctt = headers[r][j] if r in headers else ""
  451. if ctt:
  452. ctt += ":"
  453. ctt += txt
  454. if ctt:
  455. rtxt.append(ctt)
  456. if rtxt:
  457. row_txt.append("; ".join(rtxt))
  458. if cap:
  459. if is_english:
  460. from_ = " in "
  461. else:
  462. from_ = "来自"
  463. row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
  464. return row_txt
  465. @staticmethod
  466. def __cal_spans(boxes, rows, cols, tbl, html=True):
  467. # caculate span
  468. clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
  469. for cln in cols]
  470. crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
  471. for cln in cols]
  472. rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
  473. for row in rows]
  474. rbtm = [np.mean([c.get("R_btm", c["bottom"])
  475. for c in row]) for row in rows]
  476. for b in boxes:
  477. if "SP" not in b:
  478. continue
  479. b["colspan"] = [b["cn"]]
  480. b["rowspan"] = [b["rn"]]
  481. # col span
  482. for j in range(0, len(clft)):
  483. if j == b["cn"]:
  484. continue
  485. if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
  486. continue
  487. if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
  488. continue
  489. b["colspan"].append(j)
  490. # row span
  491. for j in range(0, len(rtop)):
  492. if j == b["rn"]:
  493. continue
  494. if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
  495. continue
  496. if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
  497. continue
  498. b["rowspan"].append(j)
  499. def join(arr):
  500. if not arr:
  501. return ""
  502. return "".join([t["text"] for t in arr])
  503. # rm the spaning cells
  504. for i in range(len(tbl)):
  505. for j, arr in enumerate(tbl[i]):
  506. if not arr:
  507. continue
  508. if all(["rowspan" not in a and "colspan" not in a for a in arr]):
  509. continue
  510. rowspan, colspan = [], []
  511. for a in arr:
  512. if isinstance(a.get("rowspan", 0), list):
  513. rowspan.extend(a["rowspan"])
  514. if isinstance(a.get("colspan", 0), list):
  515. colspan.extend(a["colspan"])
  516. rowspan, colspan = set(rowspan), set(colspan)
  517. if len(rowspan) < 2 and len(colspan) < 2:
  518. for a in arr:
  519. if "rowspan" in a:
  520. del a["rowspan"]
  521. if "colspan" in a:
  522. del a["colspan"]
  523. continue
  524. rowspan, colspan = sorted(rowspan), sorted(colspan)
  525. rowspan = list(range(rowspan[0], rowspan[-1] + 1))
  526. colspan = list(range(colspan[0], colspan[-1] + 1))
  527. assert i in rowspan, rowspan
  528. assert j in colspan, colspan
  529. arr = []
  530. for r in rowspan:
  531. for c in colspan:
  532. arr_txt = join(arr)
  533. if tbl[r][c] and join(tbl[r][c]) != arr_txt:
  534. arr.extend(tbl[r][c])
  535. tbl[r][c] = None if html else arr
  536. for a in arr:
  537. if len(rowspan) > 1:
  538. a["rowspan"] = len(rowspan)
  539. elif "rowspan" in a:
  540. del a["rowspan"]
  541. if len(colspan) > 1:
  542. a["colspan"] = len(colspan)
  543. elif "colspan" in a:
  544. del a["colspan"]
  545. tbl[rowspan[0]][colspan[0]] = arr
  546. return tbl