
recognizer.py 5.4KB

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import math

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

# preprocess() and the resize/normalize operator classes come from .operators
from .operators import *
from rag.settings import cron_logger


class Recognizer(object):
    def __init__(self, label_list, task_name, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not model_dir:
            model_dir = snapshot_download(repo_id="InfiniFlow/ocr")

        model_file_path = os.path.join(model_dir, task_name + ".onnx")
        if not os.path.exists(model_file_path):
            raise ValueError("Model file not found: {}".format(model_file_path))

        # Prefer the CUDA execution provider when onnxruntime reports a GPU.
        if ort.get_device() == "GPU":
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
        else:
            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
        self.label_list = label_list

    def create_inputs(self, imgs, im_info):
        """generate input for different model type
        Args:
            imgs (list(numpy)): list of images (np.ndarray)
            im_info (list(dict)): list of image info
        Returns:
            inputs (dict): input of model
        """
        inputs = {}
        im_shape = []
        scale_factor = []
        if len(imgs) == 1:
            inputs['image'] = np.array((imgs[0],)).astype('float32')
            inputs['im_shape'] = np.array(
                (im_info[0]['im_shape'],)).astype('float32')
            inputs['scale_factor'] = np.array(
                (im_info[0]['scale_factor'],)).astype('float32')
            return inputs

        for e in im_info:
            im_shape.append(np.array((e['im_shape'],)).astype('float32'))
            scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
        inputs['im_shape'] = np.concatenate(im_shape, axis=0)
        inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

        # Pad every image in the batch to the largest height/width so they
        # can be stacked into a single (N, C, H, W) tensor.
        imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
        max_shape_h = max([e[0] for e in imgs_shape])
        max_shape_w = max([e[1] for e in imgs_shape])
        padding_imgs = []
        for img in imgs:
            im_c, im_h, im_w = img.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape_h, max_shape_w), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = img
            padding_imgs.append(padding_im)
        inputs['image'] = np.stack(padding_imgs, axis=0)
        return inputs

    def preprocess(self, image_list):
        # Preprocessing pipeline: resize, normalize (ImageNet mean/std),
        # permute, then pad to a multiple of the stride.
        preprocess_ops = []
        for op_info in [
            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
            {'type': 'Permute'},
            {'stride': 32, 'type': 'PadStride'}
        ]:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # Look up the operator class by name and instantiate it
            # (the classes are imported from .operators).
            preprocess_ops.append(eval(op_type)(**new_op_info))

        inputs = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            inputs.append({"image": np.array((im,)).astype('float32'),
                           "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
        return inputs

    def __call__(self, image_list, thr=0.7, batch_size=16):
        res = []
        imgs = []
        for i in range(len(image_list)):
            if not isinstance(image_list[i], np.ndarray):
                imgs.append(np.array(image_list[i]))
            else:
                imgs.append(image_list[i])

        batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, len(imgs))
            batch_image_list = imgs[start_index:end_index]
            inputs = self.preprocess(batch_image_list)
            for ins in inputs:
                bb = []
                # Each detection row starts with the class id and score,
                # followed by the bounding-box coordinates.
                for b in self.ort_sess.run(None, ins)[0]:
                    clsid, bbox, score = int(b[0]), b[2:], b[1]
                    if score < thr:
                        continue
                    if clsid >= len(self.label_list):
                        cron_logger.warning(f"bad category id {clsid}")
                        continue
                    bb.append({
                        "type": self.label_list[clsid].lower(),
                        "bbox": [float(t) for t in bbox.tolist()],
                        "score": float(score)
                    })
                res.append(bb)

        # seeit.save_results(image_list, res, self.label_list, threshold=thr)
        return res
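
For reference, a minimal usage sketch of the class above. Only the constructor and __call__ signatures come from the file itself; the label list, task name, and input file names below are placeholders and must match whatever the downloaded `<task_name>.onnx` model was actually trained with.

# Hypothetical usage sketch; labels, task name, and image paths are assumptions.
from PIL import Image

labels = ["text", "title", "figure", "table"]        # placeholder label list
detector = Recognizer(labels, task_name="layout")    # downloads InfiniFlow/ocr when model_dir is None

pages = [Image.open("page_1.png"), Image.open("page_2.png")]
results = detector(pages, thr=0.7, batch_size=16)

# results[i] is the list of boxes kept for the i-th input image, each a dict like
# {"type": "title", "bbox": [...], "score": 0.93}
for page_boxes in results:
    for box in page_boxes:
        print(box["type"], box["score"], box["bbox"])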