
recognizer.py

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import math
from copy import deepcopy

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

from . import seeit
from .operators import *  # preprocessing op classes and preprocess() used below
from rag.settings import cron_logger


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("cannot find model file: {}".format(model_file_path))

        # prefer the CUDA provider when onnxruntime reports a GPU device
        if ort.get_device() == "GPU":
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.label_list = label_list

    @staticmethod
    def sort_Y_firstly(arr, threshold):
        # sort by top (y) first, then by x0
        arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
        # bubble pass: boxes whose tops differ by less than threshold are
        # treated as the same visual line and reordered left to right
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if abs(arr[j + 1]["top"] - arr[j]["top"]) < threshold \
                        and arr[j + 1]["x0"] < arr[j]["x0"]:
                    arr[j], arr[j + 1] = deepcopy(arr[j + 1]), deepcopy(arr[j])
        return arr
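    # e.g. with threshold=5, tops 100 and 103 count as the same visual line,
    # so {"top": 103, "x0": 0} ends up before {"top": 100, "x0": 50}.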

    @staticmethod
    def sort_X_firstly(arr, threshold, copy=True):
        # sort by x0 first, then by top (y)
        arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
        # bubble pass: boxes whose x0 differ by less than threshold are
        # treated as the same column and reordered top to bottom
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threshold \
                        and arr[j + 1]["top"] < arr[j]["top"]:
                    if copy:
                        arr[j], arr[j + 1] = deepcopy(arr[j + 1]), deepcopy(arr[j])
                    else:
                        arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr

    @staticmethod
    def sort_C_firstly(arr, thr=0):
        # sort by column id ("C") first, then by top; boxes without a "C"
        # key keep the x-first order produced by sort_X_firstly
        arr = Recognizer.sort_X_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "C" not in arr[j] or "C" not in arr[j + 1]:
                    continue
                if arr[j + 1]["C"] < arr[j]["C"] \
                        or (arr[j + 1]["C"] == arr[j]["C"]
                            and arr[j + 1]["top"] < arr[j]["top"]):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr

    @staticmethod
    def sort_R_firstly(arr, thr=0):
        # sort by row id ("R") first, then by x0; boxes without an "R" key
        # keep the y-first order produced by sort_Y_firstly
        arr = Recognizer.sort_Y_firstly(arr, thr)
        for i in range(len(arr) - 1):
            for j in range(i, -1, -1):
                if "R" not in arr[j] or "R" not in arr[j + 1]:
                    continue
                if arr[j + 1]["R"] < arr[j]["R"] \
                        or (arr[j + 1]["R"] == arr[j]["R"]
                            and arr[j + 1]["x0"] < arr[j]["x0"]):
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
        return arr

    @staticmethod
    def overlapped_area(a, b, ratio=True):
        # area of the intersection of boxes a and b; with ratio=True it is
        # returned as a fraction of a's own area
        tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
        if b["x0"] > x1 or b["x1"] < x0:
            return 0
        if b["bottom"] < tp or b["top"] > btm:
            return 0
        x0_ = max(b["x0"], x0)
        x1_ = min(b["x1"], x1)
        assert x0_ <= x1_, "Bad horizontal overlap! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        tp_ = max(b["top"], tp)
        btm_ = min(b["bottom"], btm)
        assert tp_ <= btm_, "Bad vertical overlap! T:{},B:{},X0:{},X1:{} ==> {}".format(
            tp, btm, x0, x1, b)
        ov = (btm_ - tp_) * (x1_ - x0_) if x1 - x0 != 0 and btm - tp != 0 else 0
        if ov > 0 and ratio:
            ov /= (x1 - x0) * (btm - tp)
        return ov
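    # Worked example: for a = {"top": 0, "bottom": 10, "x0": 0, "x1": 10} and
    # b = {"top": 5, "bottom": 15, "x0": 5, "x1": 15}, the intersection is
    # 5 * 5 = 25; with ratio=True it is divided by a's area (100), giving
    # 0.25. The ratio is relative to the first argument, so in general
    # overlapped_area(a, b) != overlapped_area(b, a).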

    @staticmethod
    def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
        # drop duplicated layout regions: when two nearby layouts of the
        # same type overlap by more than thr, keep the higher-scored one
        # (or, lacking scores, the one covering more text-box area)
        def notOverlapped(a, b):
            return any([a["x1"] < b["x0"],
                        a["x0"] > b["x1"],
                        a["bottom"] < b["top"],
                        a["top"] > b["bottom"]])

        i = 0
        while i + 1 < len(layouts):
            # look ahead at most `far` layouts for an overlapping duplicate
            j = i + 1
            while j < min(i + far, len(layouts)) \
                    and (layouts[i].get("type", "") != layouts[j].get("type", "")
                         or notOverlapped(layouts[i], layouts[j])):
                j += 1
            if j >= min(i + far, len(layouts)):
                i += 1
                continue
            if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                    and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
                i += 1
                continue

            if layouts[i].get("score") and layouts[j].get("score"):
                if layouts[i]["score"] > layouts[j]["score"]:
                    layouts.pop(j)
                else:
                    layouts.pop(i)
                continue

            # no scores: keep the layout covering more box area
            area_i, area_j = 0, 0
            for b in boxes:
                if not notOverlapped(b, layouts[i]):
                    area_i += Recognizer.overlapped_area(b, layouts[i], False)
                if not notOverlapped(b, layouts[j]):
                    area_j += Recognizer.overlapped_area(b, layouts[j], False)
            if area_i > area_j:
                layouts.pop(j)
            else:
                layouts.pop(i)

        return layouts
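    # e.g. two neighbouring "table" layouts overlapping by more than 70%
    # with scores 0.9 and 0.6: the 0.6 one is popped and the scan resumes
    # from the same i.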

    def create_inputs(self, imgs, im_info):
        """Generate the ONNX input feed for one batch.

        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        # zero-pad every CHW image to the largest height/width in the batch
        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs
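    # e.g. two CHW images shaped (3, 600, 800) and (3, 640, 704) are each
    # zero-padded to (3, 640, 800), so inputs['image'] gets shape
    # (2, 3, 640, 800).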

    @staticmethod
    def find_overlapped(box, boxes_sorted_by_y, naive=False):
        # return the index of the entry in boxes_sorted_by_y that overlaps
        # `box` the most (by overlap ratio), or None if nothing overlaps
        if not boxes_sorted_by_y:
            return
        bxs = boxes_sorted_by_y
        s, e, ii = 0, len(bxs), 0
        # binary search for a box whose vertical span intersects `box`
        while s < e and not naive:
            ii = (e + s) // 2
            pv = bxs[ii]
            if box["bottom"] < pv["top"]:
                e = ii
                continue
            if box["top"] > pv["bottom"]:
                s = ii + 1
                continue
            break

        # trim one clearly non-overlapping box from each end of the window
        while s < ii:
            if box["top"] > bxs[s]["bottom"]:
                s += 1
            break
        while e - 1 > ii:
            if box["bottom"] < bxs[e - 1]["top"]:
                e -= 1
            break

        max_overlapped_i, max_overlapped = None, 0
        for i in range(s, e):
            ov = Recognizer.overlapped_area(bxs[i], box)
            if ov <= max_overlapped:
                continue
            max_overlapped_i = i
            max_overlapped = ov
        return max_overlapped_i
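    # e.g. with bxs sorted by "top" and a query box spanning tops 9..21, the
    # bisection lands on some vertically intersecting candidate; the final
    # scan over [s, e) returns the index with the largest overlap ratio.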

    @staticmethod
    def find_overlapped_with_threashold(box, boxes, thr=0.3):
        # like find_overlapped, but a match must reach thr on the forward
        # overlap ratio; ties fall back to the reverse ratio
        if not boxes:
            return
        max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
        s, e = 0, len(boxes)
        for i in range(s, e):
            ov = Recognizer.overlapped_area(box, boxes[i])
            _ov = Recognizer.overlapped_area(boxes[i], box)
            if (ov, _ov) < (max_overlapped, _max_overlapped):
                continue
            max_overlapped_i = i
            max_overlapped = ov
            _max_overlapped = _ov
        return max_overlapped_i

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # instantiate the op class of that name from .operators
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({
                "image": np.array((im,)).astype('float32'),
                "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')
            })
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                # each detection row is (class_id, score, bbox coordinates)
                for b in self.ort_sess.run(None, ins)[0]:
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    if clsid >= len(self.label_list):
                        cron_logger.warning(f"bad category id: {clsid}")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        # optional visual debugging:
        # seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
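

# Minimal usage sketch. The label list and task name below are hypothetical
# examples for illustration, not values shipped with this module:
#
#     labels = ["text", "title", "figure", "table"]   # hypothetical labels
#     rec = Recognizer(labels, "layout")  # fetches InfiniFlow/ocr when model_dir is None
#     results = rec(images, thr=0.7)      # images: PIL images or np.ndarrays
#     for page_boxes in results:
#         for b in page_boxes:
#             print(b["type"], b["score"], b["bbox"])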