
recognizer.py

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import math
import os
from copy import deepcopy

import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory

from .operators import *  # noqa: F403
from .operators import preprocess
from . import operators


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = os.path.join(
                get_project_base_directory(),
                "rag/res/deepdoc")
            model_file_path = os.path.join(model_dir, task_name + ".onnx")
            if not os.path.exists(model_file_path):
                model_dir = snapshot_download(
                    repo_id="InfiniFlow/deepdoc",
                    local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                    local_dir_use_symlinks=False)
                model_file_path = os.path.join(model_dir, task_name + ".onnx")
        else:
            model_file_path = os.path.join(model_dir, task_name + ".onnx")
            if not os.path.exists(model_file_path):
                raise ValueError("Model file not found: {}".format(model_file_path))

        def cuda_is_available():
            try:
                import torch
                if torch.cuda.is_available():
                    return True
            except Exception:
                return False
            return False

        # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
        # Shrink GPU memory after execution
        self.run_options = ort.RunOptions()
        if cuda_is_available():
            options = ort.SessionOptions()
            options.enable_cpu_mem_arena = False
            cuda_provider_options = {
                "device_id": 0,  # Use specific GPU
                "gpu_mem_limit": 512 * 1024 * 1024,  # Limit GPU memory
                "arena_extend_strategy": "kNextPowerOfTwo",  # GPU memory allocation strategy
            }
            self.ort_sess = ort.InferenceSession(
                model_file_path, options=options,
                providers=['CUDAExecutionProvider'],
                provider_options=[cuda_provider_options]
            )
            self.run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
            logging.info(f"Recognizer {task_name} uses GPU")
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
            self.run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
            logging.info(f"Recognizer {task_name} uses CPU")
        self.input_names = [node.name for node in self.ort_sess.get_inputs()]
        self.output_names = [node.name for node in self.ort_sess.get_outputs()]
        self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
        self.label_list = label_list
    @staticmethod
    def sort_Y_firstly(arr, threshold):
        # sort by top (y) first, then by x0
        arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # swap neighbors whose tops are within the threshold
                # (same visual line) but whose left-to-right order is inverted
                if abs(arr[j + 1]["top"] - arr[j]["top"]) < threshold \
                        and arr[j + 1]["x0"] < arr[j]["x0"]:
                    arr[j], arr[j + 1] = deepcopy(arr[j + 1]), deepcopy(arr[j])
        return arr
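
    # Illustration (added, not in the original source): with threshold=4, boxes
    # whose tops are 10 and 12 fall within the threshold and count as one
    # visual line, so sort_Y_firstly orders them left to right by x0 rather
    # than strictly by top.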
    @staticmethod
    def sort_X_firstly(arr, threshold, copy=True):
        # sort by x0 first, then by top (y)
        arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # swap neighbors whose x0 values are within the threshold
                # (same column) but whose vertical order is inverted
                if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threshold \
                        and arr[j + 1]["top"] < arr[j]["top"]:
                    if copy:
                        arr[j], arr[j + 1] = deepcopy(arr[j + 1]), deepcopy(arr[j])
                    else:
                        arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr
    @staticmethod
    def sort_C_firstly(arr, thr=0):
        # sort by column ("C") first, then by top (y)
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                # swap neighbors whose column order, or vertical order within
                # the same column, is inverted
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] \
                        or (arr[j + 1]["C"] == arr[j]["C"]
                            and arr[j + 1]["top"] < arr[j]["top"]):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr
    @staticmethod
    def sort_R_firstly(arr, thr=0):
        # sort by row ("R") first, then by x0
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] \
                        or (arr[j + 1]["R"] == arr[j]["R"]
                            and arr[j + 1]["x0"] < arr[j]["x0"]):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr
    @staticmethod
    def overlapped_area(a, b, ratio=True):
        # intersection area of boxes a and b; with ratio=True it is returned
        # as a fraction of a's own area, so the function is asymmetric
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(
            tp, btm, x0, x1, b)
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov
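
    # Worked example (added, not in the original source): for
    #   a = {"top": 0, "bottom": 10, "x0": 0, "x1": 10}
    #   b = {"top": 5, "bottom": 15, "x0": 5, "x1": 15}
    # the intersection is 5 * 5 = 25, so overlapped_area(a, b) returns
    # 25 / 100 = 0.25 with ratio=True, since the denominator is the area
    # of the first argument (a covers 10 * 10 = 100).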
    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        # drop duplicated layout regions: among nearby layouts of the same
        # type that mutually overlap by more than thr, keep the one with the
        # higher score or, failing that, the one covering more text-box area
        def notOverlapped(a, b):
            return any([a["x1"] < b["x0"],
                        a["x0"] > b["x1"],
                        a["bottom"] < b["top"],
                        a["top"] > b["bottom"]])

        i = 0
        while i + 1 < len(layouts):
            j = i + 1
            while j < min(i + far, len(layouts)) \
                    and (layouts[i].get("type", "") != layouts[j].get("type", "")
                         or notOverlapped(layouts[i], layouts[j])):
                j += 1
            if j >= min(i + far, len(layouts)):
                i += 1
                continue
            if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
                i += 1
                continue

            if layouts[i].get("score") and layouts[j].get("score"):
                if layouts[i]["score"] > layouts[j]["score"]:
                    layouts.pop(j)
                else:
                    layouts.pop(i)
                continue

            area_i, area_i_1 = 0, 0
            for b in boxes:
                if not notOverlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not notOverlapped(b, layouts[j]):
                    area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)

            if area_i > area_i_1:
                layouts.pop(j)
            else:
                layouts.pop(i)

        return layouts
    def create_inputs(self, imgs, im_info):
        """Generate model inputs, zero-padding multi-image batches to a common shape.
        Args:
            imgs (list[np.ndarray]): list of images
            im_info (list[dict]): list of image info dicts
        Returns:
            inputs (dict): input feed for the model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))

        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            # zero-pad every image in the batch to the largest H x W
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs
    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        s, e, ii = 0, len(bxs), 0
        # binary search for a box whose vertical range intersects `box`
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break
        # nudge the window bounds once toward the hit
        # (the loops intentionally break after a single pass)
        while s < ii:
            if box["top"] > bxs[s]["bottom"]:
                s += 1
            break
        while e - 1 > ii:
            if box["bottom"] < bxs[e - 1]["top"]:
                e -= 1
            break

        max_overlapped_i, max_overlapped = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov <= max_overlapped:
                continue
            max_overlapped_i = i
            max_overlapped = ov

        return max_overlapped_i
    @staticmethod
    def find_horizontally_tightest_fit(box, boxes):
        if not boxes:
            return
        min_dis, min_i = 1000000, None
        for i, b in enumerate(boxes):
            if box.get("layoutno", "0") != b.get("layoutno", "0"):
                continue
            # distance = best of left-edge, right-edge and center alignment
            dis = min(abs(box["x0"] - b["x0"]),
                      abs(box["x1"] - b["x1"]),
                      abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
            if dis < min_dis:
                min_i = i
                min_dis = dis
        return min_i
    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        if not boxes:
            return
        max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
        s, e = 0, len(boxes)
        for i in range(s, e):
            ov = Recognizer.overlapped_area(box, boxes[i])
            _ov = Recognizer.overlapped_area(boxes[i], box)
            # lexicographic tuple comparison: prefer a larger overlap ratio of
            # `box`, breaking ties by the reverse ratio; thr is the floor
            if (ov, _ov) < (max_overlapped, _max_overlapped):
                continue
            max_overlapped_i = i
            max_overlapped = ov
            _max_overlapped = _ov

        return max_overlapped_i
    def preprocess(self, image_list):
        inputs = []
        if "scale_factor" in self.input_names:
            # model takes an explicit scale_factor input:
            # resize / normalize / permute / pad-to-stride pipeline
            preprocess_ops = []
            for op_info in [
                {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
                {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
                {'type': 'Permute'},
                {'stride': 32, 'type': 'PadStride'}
            ]:
                new_op_info = op_info.copy()
                op_type = new_op_info.pop('type')
                preprocess_ops.append(getattr(operators, op_type)(**new_op_info))

            for im_path in image_list:
                im, im_info = preprocess(im_path, preprocess_ops)
                inputs.append({"image": np.array((im,)).astype('float32'),
                               "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        else:
            hh, ww = self.input_shape
            for img in image_list:
                h, w = img.shape[:2]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
                # Scale input pixel values to 0 to 1
                img /= 255.0
                img = img.transpose(2, 0, 1)
                img = img[np.newaxis, :, :, :].astype(np.float32)
                inputs.append({self.input_names[0]: img, "scale_factor": [w / ww, h / hh]})
        return inputs
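
    # Note (added): the two branches above mirror the two model families this
    # class serves. Models exported with an explicit "scale_factor" input
    # (PaddleDetection-style detectors) run the LinearResize/StandardizeImage/
    # Permute/PadStride pipeline and emit boxes already in source-image
    # coordinates, while models without it (YOLO-style exports) get a plain
    # resize to the ONNX input shape and have their boxes rescaled by
    # "scale_factor" in postprocess().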
    def postprocess(self, boxes, inputs, thr):
        if "scale_factor" in self.input_names:
            # detector output rows are [class_id, score, x0, y0, x1, y1],
            # already in source-image coordinates
            bb = []
            for b in boxes:
                clsid, bbox, score = int(b[0]), b[2:], b[1]
                if score < thr:
                    continue
                if clsid >= len(self.label_list):
                    continue
                bb.append({
                    "type": self.label_list[clsid].lower(),
                    "bbox": [float(t) for t in bbox.tolist()],
                    "score": float(score)
                })
            return bb

        def xywh2xyxy(x):
            # [x, y, w, h] to [x1, y1, x2, y2]
            y = np.copy(x)
            y[:, 0] = x[:, 0] - x[:, 2] / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            return y

        def compute_iou(box, boxes):
            # Compute xmin, ymin, xmax, ymax of the intersections
            xmin = np.maximum(box[0], boxes[:, 0])
            ymin = np.maximum(box[1], boxes[:, 1])
            xmax = np.minimum(box[2], boxes[:, 2])
            ymax = np.minimum(box[3], boxes[:, 3])

            # Compute intersection area
            intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

            # Compute union area
            box_area = (box[2] - box[0]) * (box[3] - box[1])
            boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            union_area = box_area + boxes_area - intersection_area

            # Compute IoU
            return intersection_area / union_area

        def iou_filter(boxes, scores, iou_threshold):
            # greedy non-maximum suppression
            sorted_indices = np.argsort(scores)[::-1]
            keep_boxes = []
            while sorted_indices.size > 0:
                # Pick the box with the highest remaining score
                box_id = sorted_indices[0]
                keep_boxes.append(box_id)

                # Compute IoU of the picked box with the rest
                ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

                # Remove boxes with IoU over the threshold
                keep_indices = np.where(ious < iou_threshold)[0]
                sorted_indices = sorted_indices[keep_indices + 1]

            return keep_boxes

        # YOLO-style raw output:
        # (1, 4 + num_classes, num_anchors) -> (num_anchors, 4 + num_classes)
        boxes = np.squeeze(boxes).T
        # Filter out object confidence scores below threshold
        scores = np.max(boxes[:, 4:], axis=1)
        boxes = boxes[scores > thr, :]
        scores = scores[scores > thr]
        if len(boxes) == 0:
            return []

        # Get the class with the highest confidence
        class_ids = np.argmax(boxes[:, 4:], axis=1)
        boxes = boxes[:, :4]
        input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1],
                                inputs["scale_factor"][0], inputs["scale_factor"][1]])
        boxes = np.multiply(boxes, input_shape, dtype=np.float32)
        boxes = xywh2xyxy(boxes)

        # per-class NMS with an IoU threshold of 0.2
        unique_class_ids = np.unique(class_ids)
        indices = []
        for class_id in unique_class_ids:
            class_indices = np.where(class_ids == class_id)[0]
            class_boxes = boxes[class_indices, :]
            class_scores = scores[class_indices]
            class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
            indices.extend(class_indices[class_keep_boxes])

        return [{
            "type": self.label_list[class_ids[i]].lower(),
            "bbox": [float(t) for t in boxes[i].tolist()],
            "score": float(scores[i])
        } for i in indices]
    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            logging.debug("preprocess")
            for ins in inputs:
                bb = self.postprocess(
                    self.ort_sess.run(None,
                                      {k: v for k, v in ins.items() if k in self.input_names},
                                      self.run_options)[0],
                    ins, thr)
                res.append(bb)

        return res
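

if __name__ == "__main__":
    # Minimal usage sketch (added; not part of the original module). It assumes
    # the InfiniFlow/deepdoc "layout" ONNX model can be found locally or
    # downloaded, and that an image path is given on the command line; the
    # label list below is an illustrative placeholder, not necessarily the
    # label set the real layout model was trained with.
    import sys

    if len(sys.argv) < 2:
        raise SystemExit("usage: python recognizer.py <image>")
    labels = ["text", "title", "figure", "table"]  # hypothetical labels
    detector = Recognizer(labels, "layout")
    image = cv2.imread(sys.argv[1])
    if image is None:
        raise SystemExit("could not read image: {}".format(sys.argv[1]))
    for box in detector([image], thr=0.5)[0]:
        print(box["type"], round(box["score"], 3), box["bbox"])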