You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

operators.py 24KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import sys
  18. import six
  19. import cv2
  20. import numpy as np
  21. import math
  22. from PIL import Image
  23. class DecodeImage(object):
  24. """ decode image """
  25. def __init__(self,
  26. img_mode='RGB',
  27. channel_first=False,
  28. ignore_orientation=False,
  29. **kwargs):
  30. self.img_mode = img_mode
  31. self.channel_first = channel_first
  32. self.ignore_orientation = ignore_orientation
  33. def __call__(self, data):
  34. img = data['image']
  35. if six.PY2:
  36. assert isinstance(img, str) and len(
  37. img) > 0, "invalid input 'img' in DecodeImage"
  38. else:
  39. assert isinstance(img, bytes) and len(
  40. img) > 0, "invalid input 'img' in DecodeImage"
  41. img = np.frombuffer(img, dtype='uint8')
  42. if self.ignore_orientation:
  43. img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
  44. cv2.IMREAD_COLOR)
  45. else:
  46. img = cv2.imdecode(img, 1)
  47. if img is None:
  48. return None
  49. if self.img_mode == 'GRAY':
  50. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  51. elif self.img_mode == 'RGB':
  52. assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
  53. img.shape)
  54. img = img[:, :, ::-1]
  55. if self.channel_first:
  56. img = img.transpose((2, 0, 1))
  57. data['image'] = img
  58. return data
  59. class StandardizeImage(object):
  60. """normalize image
  61. Args:
  62. mean (list): im - mean
  63. std (list): im / std
  64. is_scale (bool): whether need im / 255
  65. norm_type (str): type in ['mean_std', 'none']
  66. """
  67. def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
  68. self.mean = mean
  69. self.std = std
  70. self.is_scale = is_scale
  71. self.norm_type = norm_type
  72. def __call__(self, im, im_info):
  73. """
  74. Args:
  75. im (np.ndarray): image (np.ndarray)
  76. im_info (dict): info of image
  77. Returns:
  78. im (np.ndarray): processed image (np.ndarray)
  79. im_info (dict): info of processed image
  80. """
  81. im = im.astype(np.float32, copy=False)
  82. if self.is_scale:
  83. scale = 1.0 / 255.0
  84. im *= scale
  85. if self.norm_type == 'mean_std':
  86. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  87. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  88. im -= mean
  89. im /= std
  90. return im, im_info
  91. class NormalizeImage(object):
  92. """ normalize image such as subtract mean, divide std
  93. """
  94. def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
  95. if isinstance(scale, str):
  96. scale = eval(scale)
  97. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  98. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  99. std = std if std is not None else [0.229, 0.224, 0.225]
  100. shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
  101. self.mean = np.array(mean).reshape(shape).astype('float32')
  102. self.std = np.array(std).reshape(shape).astype('float32')
  103. def __call__(self, data):
  104. img = data['image']
  105. from PIL import Image
  106. if isinstance(img, Image.Image):
  107. img = np.array(img)
  108. assert isinstance(img,
  109. np.ndarray), "invalid input 'img' in NormalizeImage"
  110. data['image'] = (
  111. img.astype('float32') * self.scale - self.mean) / self.std
  112. return data
  113. class ToCHWImage(object):
  114. """ convert hwc image to chw image
  115. """
  116. def __init__(self, **kwargs):
  117. pass
  118. def __call__(self, data):
  119. img = data['image']
  120. from PIL import Image
  121. if isinstance(img, Image.Image):
  122. img = np.array(img)
  123. data['image'] = img.transpose((2, 0, 1))
  124. return data
  125. class Fasttext(object):
  126. def __init__(self, path="None", **kwargs):
  127. import fasttext
  128. self.fast_model = fasttext.load_model(path)
  129. def __call__(self, data):
  130. label = data['label']
  131. fast_label = self.fast_model[label]
  132. data['fast_label'] = fast_label
  133. return data
  134. class KeepKeys(object):
  135. def __init__(self, keep_keys, **kwargs):
  136. self.keep_keys = keep_keys
  137. def __call__(self, data):
  138. data_list = []
  139. for key in self.keep_keys:
  140. data_list.append(data[key])
  141. return data_list
  142. class Pad(object):
  143. def __init__(self, size=None, size_div=32, **kwargs):
  144. if size is not None and not isinstance(size, (int, list, tuple)):
  145. raise TypeError("Type of target_size is invalid. Now is {}".format(
  146. type(size)))
  147. if isinstance(size, int):
  148. size = [size, size]
  149. self.size = size
  150. self.size_div = size_div
  151. def __call__(self, data):
  152. img = data['image']
  153. img_h, img_w = img.shape[0], img.shape[1]
  154. if self.size:
  155. resize_h2, resize_w2 = self.size
  156. assert (
  157. img_h < resize_h2 and img_w < resize_w2
  158. ), '(h, w) of target size should be greater than (img_h, img_w)'
  159. else:
  160. resize_h2 = max(
  161. int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
  162. self.size_div)
  163. resize_w2 = max(
  164. int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
  165. self.size_div)
  166. img = cv2.copyMakeBorder(
  167. img,
  168. 0,
  169. resize_h2 - img_h,
  170. 0,
  171. resize_w2 - img_w,
  172. cv2.BORDER_CONSTANT,
  173. value=0)
  174. data['image'] = img
  175. return data
  176. class LinearResize(object):
  177. """resize image by target_size and max_size
  178. Args:
  179. target_size (int): the target size of image
  180. keep_ratio (bool): whether keep_ratio or not, default true
  181. interp (int): method of resize
  182. """
  183. def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
  184. if isinstance(target_size, int):
  185. target_size = [target_size, target_size]
  186. self.target_size = target_size
  187. self.keep_ratio = keep_ratio
  188. self.interp = interp
  189. def __call__(self, im, im_info):
  190. """
  191. Args:
  192. im (np.ndarray): image (np.ndarray)
  193. im_info (dict): info of image
  194. Returns:
  195. im (np.ndarray): processed image (np.ndarray)
  196. im_info (dict): info of processed image
  197. """
  198. assert len(self.target_size) == 2
  199. assert self.target_size[0] > 0 and self.target_size[1] > 0
  200. _im_channel = im.shape[2]
  201. im_scale_y, im_scale_x = self.generate_scale(im)
  202. im = cv2.resize(
  203. im,
  204. None,
  205. None,
  206. fx=im_scale_x,
  207. fy=im_scale_y,
  208. interpolation=self.interp)
  209. im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
  210. im_info['scale_factor'] = np.array(
  211. [im_scale_y, im_scale_x]).astype('float32')
  212. return im, im_info
  213. def generate_scale(self, im):
  214. """
  215. Args:
  216. im (np.ndarray): image (np.ndarray)
  217. Returns:
  218. im_scale_x: the resize ratio of X
  219. im_scale_y: the resize ratio of Y
  220. """
  221. origin_shape = im.shape[:2]
  222. _im_c = im.shape[2]
  223. if self.keep_ratio:
  224. im_size_min = np.min(origin_shape)
  225. im_size_max = np.max(origin_shape)
  226. target_size_min = np.min(self.target_size)
  227. target_size_max = np.max(self.target_size)
  228. im_scale = float(target_size_min) / float(im_size_min)
  229. if np.round(im_scale * im_size_max) > target_size_max:
  230. im_scale = float(target_size_max) / float(im_size_max)
  231. im_scale_x = im_scale
  232. im_scale_y = im_scale
  233. else:
  234. resize_h, resize_w = self.target_size
  235. im_scale_y = resize_h / float(origin_shape[0])
  236. im_scale_x = resize_w / float(origin_shape[1])
  237. return im_scale_y, im_scale_x
  238. class Resize(object):
  239. def __init__(self, size=(640, 640), **kwargs):
  240. self.size = size
  241. def resize_image(self, img):
  242. resize_h, resize_w = self.size
  243. ori_h, ori_w = img.shape[:2] # (h, w, c)
  244. ratio_h = float(resize_h) / ori_h
  245. ratio_w = float(resize_w) / ori_w
  246. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  247. return img, [ratio_h, ratio_w]
  248. def __call__(self, data):
  249. img = data['image']
  250. if 'polys' in data:
  251. text_polys = data['polys']
  252. img_resize, [ratio_h, ratio_w] = self.resize_image(img)
  253. if 'polys' in data:
  254. new_boxes = []
  255. for box in text_polys:
  256. new_box = []
  257. for cord in box:
  258. new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
  259. new_boxes.append(new_box)
  260. data['polys'] = np.array(new_boxes, dtype=np.float32)
  261. data['image'] = img_resize
  262. return data
  263. class DetResizeForTest(object):
  264. def __init__(self, **kwargs):
  265. super(DetResizeForTest, self).__init__()
  266. self.resize_type = 0
  267. self.keep_ratio = False
  268. if 'image_shape' in kwargs:
  269. self.image_shape = kwargs['image_shape']
  270. self.resize_type = 1
  271. if 'keep_ratio' in kwargs:
  272. self.keep_ratio = kwargs['keep_ratio']
  273. elif 'limit_side_len' in kwargs:
  274. self.limit_side_len = kwargs['limit_side_len']
  275. self.limit_type = kwargs.get('limit_type', 'min')
  276. elif 'resize_long' in kwargs:
  277. self.resize_type = 2
  278. self.resize_long = kwargs.get('resize_long', 960)
  279. else:
  280. self.limit_side_len = 736
  281. self.limit_type = 'min'
  282. def __call__(self, data):
  283. img = data['image']
  284. src_h, src_w, _ = img.shape
  285. if sum([src_h, src_w]) < 64:
  286. img = self.image_padding(img)
  287. if self.resize_type == 0:
  288. # img, shape = self.resize_image_type0(img)
  289. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  290. elif self.resize_type == 2:
  291. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  292. else:
  293. # img, shape = self.resize_image_type1(img)
  294. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  295. data['image'] = img
  296. data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
  297. return data
  298. def image_padding(self, im, value=0):
  299. h, w, c = im.shape
  300. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  301. im_pad[:h, :w, :] = im
  302. return im_pad
  303. def resize_image_type1(self, img):
  304. resize_h, resize_w = self.image_shape
  305. ori_h, ori_w = img.shape[:2] # (h, w, c)
  306. if self.keep_ratio is True:
  307. resize_w = ori_w * resize_h / ori_h
  308. N = math.ceil(resize_w / 32)
  309. resize_w = N * 32
  310. ratio_h = float(resize_h) / ori_h
  311. ratio_w = float(resize_w) / ori_w
  312. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  313. # return img, np.array([ori_h, ori_w])
  314. return img, [ratio_h, ratio_w]
  315. def resize_image_type0(self, img):
  316. """
  317. resize image to a size multiple of 32 which is required by the network
  318. args:
  319. img(array): array with shape [h, w, c]
  320. return(tuple):
  321. img, (ratio_h, ratio_w)
  322. """
  323. limit_side_len = self.limit_side_len
  324. h, w, c = img.shape
  325. # limit the max side
  326. if self.limit_type == 'max':
  327. if max(h, w) > limit_side_len:
  328. if h > w:
  329. ratio = float(limit_side_len) / h
  330. else:
  331. ratio = float(limit_side_len) / w
  332. else:
  333. ratio = 1.
  334. elif self.limit_type == 'min':
  335. if min(h, w) < limit_side_len:
  336. if h < w:
  337. ratio = float(limit_side_len) / h
  338. else:
  339. ratio = float(limit_side_len) / w
  340. else:
  341. ratio = 1.
  342. elif self.limit_type == 'resize_long':
  343. ratio = float(limit_side_len) / max(h, w)
  344. else:
  345. raise Exception('not support limit type, image ')
  346. resize_h = int(h * ratio)
  347. resize_w = int(w * ratio)
  348. resize_h = max(int(round(resize_h / 32) * 32), 32)
  349. resize_w = max(int(round(resize_w / 32) * 32), 32)
  350. try:
  351. if int(resize_w) <= 0 or int(resize_h) <= 0:
  352. return None, (None, None)
  353. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  354. except BaseException:
  355. logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
  356. sys.exit(0)
  357. ratio_h = resize_h / float(h)
  358. ratio_w = resize_w / float(w)
  359. return img, [ratio_h, ratio_w]
  360. def resize_image_type2(self, img):
  361. h, w, _ = img.shape
  362. resize_w = w
  363. resize_h = h
  364. if resize_h > resize_w:
  365. ratio = float(self.resize_long) / resize_h
  366. else:
  367. ratio = float(self.resize_long) / resize_w
  368. resize_h = int(resize_h * ratio)
  369. resize_w = int(resize_w * ratio)
  370. max_stride = 128
  371. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  372. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  373. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  374. ratio_h = resize_h / float(h)
  375. ratio_w = resize_w / float(w)
  376. return img, [ratio_h, ratio_w]
  377. class E2EResizeForTest(object):
  378. def __init__(self, **kwargs):
  379. super(E2EResizeForTest, self).__init__()
  380. self.max_side_len = kwargs['max_side_len']
  381. self.valid_set = kwargs['valid_set']
  382. def __call__(self, data):
  383. img = data['image']
  384. src_h, src_w, _ = img.shape
  385. if self.valid_set == 'totaltext':
  386. im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
  387. img, max_side_len=self.max_side_len)
  388. else:
  389. im_resized, (ratio_h, ratio_w) = self.resize_image(
  390. img, max_side_len=self.max_side_len)
  391. data['image'] = im_resized
  392. data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
  393. return data
  394. def resize_image_for_totaltext(self, im, max_side_len=512):
  395. h, w, _ = im.shape
  396. resize_w = w
  397. resize_h = h
  398. ratio = 1.25
  399. if h * ratio > max_side_len:
  400. ratio = float(max_side_len) / resize_h
  401. resize_h = int(resize_h * ratio)
  402. resize_w = int(resize_w * ratio)
  403. max_stride = 128
  404. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  405. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  406. im = cv2.resize(im, (int(resize_w), int(resize_h)))
  407. ratio_h = resize_h / float(h)
  408. ratio_w = resize_w / float(w)
  409. return im, (ratio_h, ratio_w)
  410. def resize_image(self, im, max_side_len=512):
  411. """
  412. resize image to a size multiple of max_stride which is required by the network
  413. :param im: the resized image
  414. :param max_side_len: limit of max image size to avoid out of memory in gpu
  415. :return: the resized image and the resize ratio
  416. """
  417. h, w, _ = im.shape
  418. resize_w = w
  419. resize_h = h
  420. # Fix the longer side
  421. if resize_h > resize_w:
  422. ratio = float(max_side_len) / resize_h
  423. else:
  424. ratio = float(max_side_len) / resize_w
  425. resize_h = int(resize_h * ratio)
  426. resize_w = int(resize_w * ratio)
  427. max_stride = 128
  428. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  429. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  430. im = cv2.resize(im, (int(resize_w), int(resize_h)))
  431. ratio_h = resize_h / float(h)
  432. ratio_w = resize_w / float(w)
  433. return im, (ratio_h, ratio_w)
  434. class KieResize(object):
  435. def __init__(self, **kwargs):
  436. super(KieResize, self).__init__()
  437. self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
  438. 'img_scale'][1]
  439. def __call__(self, data):
  440. img = data['image']
  441. points = data['points']
  442. src_h, src_w, _ = img.shape
  443. im_resized, scale_factor, [ratio_h, ratio_w
  444. ], [new_h, new_w] = self.resize_image(img)
  445. resize_points = self.resize_boxes(img, points, scale_factor)
  446. data['ori_image'] = img
  447. data['ori_boxes'] = points
  448. data['points'] = resize_points
  449. data['image'] = im_resized
  450. data['shape'] = np.array([new_h, new_w])
  451. return data
  452. def resize_image(self, img):
  453. norm_img = np.zeros([1024, 1024, 3], dtype='float32')
  454. scale = [512, 1024]
  455. h, w = img.shape[:2]
  456. max_long_edge = max(scale)
  457. max_short_edge = min(scale)
  458. scale_factor = min(max_long_edge / max(h, w),
  459. max_short_edge / min(h, w))
  460. resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
  461. scale_factor) + 0.5)
  462. max_stride = 32
  463. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  464. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  465. im = cv2.resize(img, (resize_w, resize_h))
  466. new_h, new_w = im.shape[:2]
  467. w_scale = new_w / w
  468. h_scale = new_h / h
  469. scale_factor = np.array(
  470. [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
  471. norm_img[:new_h, :new_w, :] = im
  472. return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
  473. def resize_boxes(self, im, points, scale_factor):
  474. points = points * scale_factor
  475. img_shape = im.shape[:2]
  476. points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
  477. points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
  478. return points
  479. class SRResize(object):
  480. def __init__(self,
  481. imgH=32,
  482. imgW=128,
  483. down_sample_scale=4,
  484. keep_ratio=False,
  485. min_ratio=1,
  486. mask=False,
  487. infer_mode=False,
  488. **kwargs):
  489. self.imgH = imgH
  490. self.imgW = imgW
  491. self.keep_ratio = keep_ratio
  492. self.min_ratio = min_ratio
  493. self.down_sample_scale = down_sample_scale
  494. self.mask = mask
  495. self.infer_mode = infer_mode
  496. def __call__(self, data):
  497. imgH = self.imgH
  498. imgW = self.imgW
  499. images_lr = data["image_lr"]
  500. transform2 = ResizeNormalize(
  501. (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
  502. images_lr = transform2(images_lr)
  503. data["img_lr"] = images_lr
  504. if self.infer_mode:
  505. return data
  506. images_HR = data["image_hr"]
  507. _label_strs = data["label"]
  508. transform = ResizeNormalize((imgW, imgH))
  509. images_HR = transform(images_HR)
  510. data["img_hr"] = images_HR
  511. return data
  512. class ResizeNormalize(object):
  513. def __init__(self, size, interpolation=Image.BICUBIC):
  514. self.size = size
  515. self.interpolation = interpolation
  516. def __call__(self, img):
  517. img = img.resize(self.size, self.interpolation)
  518. img_numpy = np.array(img).astype("float32")
  519. img_numpy = img_numpy.transpose((2, 0, 1)) / 255
  520. return img_numpy
  521. class GrayImageChannelFormat(object):
  522. """
  523. format gray scale image's channel: (3,h,w) -> (1,h,w)
  524. Args:
  525. inverse: inverse gray image
  526. """
  527. def __init__(self, inverse=False, **kwargs):
  528. self.inverse = inverse
  529. def __call__(self, data):
  530. img = data['image']
  531. img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  532. img_expanded = np.expand_dims(img_single_channel, 0)
  533. if self.inverse:
  534. data['image'] = np.abs(img_expanded - 1)
  535. else:
  536. data['image'] = img_expanded
  537. data['src_image'] = img
  538. return data
  539. class Permute(object):
  540. """permute image
  541. Args:
  542. to_bgr (bool): whether convert RGB to BGR
  543. channel_first (bool): whether convert HWC to CHW
  544. """
  545. def __init__(self, ):
  546. super(Permute, self).__init__()
  547. def __call__(self, im, im_info):
  548. """
  549. Args:
  550. im (np.ndarray): image (np.ndarray)
  551. im_info (dict): info of image
  552. Returns:
  553. im (np.ndarray): processed image (np.ndarray)
  554. im_info (dict): info of processed image
  555. """
  556. im = im.transpose((2, 0, 1)).copy()
  557. return im, im_info
  558. class PadStride(object):
  559. """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
  560. Args:
  561. stride (bool): model with FPN need image shape % stride == 0
  562. """
  563. def __init__(self, stride=0):
  564. self.coarsest_stride = stride
  565. def __call__(self, im, im_info):
  566. """
  567. Args:
  568. im (np.ndarray): image (np.ndarray)
  569. im_info (dict): info of image
  570. Returns:
  571. im (np.ndarray): processed image (np.ndarray)
  572. im_info (dict): info of processed image
  573. """
  574. coarsest_stride = self.coarsest_stride
  575. if coarsest_stride <= 0:
  576. return im, im_info
  577. im_c, im_h, im_w = im.shape
  578. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  579. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  580. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  581. padding_im[:, :im_h, :im_w] = im
  582. return padding_im, im_info
  583. def decode_image(im_file, im_info):
  584. """read rgb image
  585. Args:
  586. im_file (str|np.ndarray): input can be image path or np.ndarray
  587. im_info (dict): info of image
  588. Returns:
  589. im (np.ndarray): processed image (np.ndarray)
  590. im_info (dict): info of processed image
  591. """
  592. if isinstance(im_file, str):
  593. with open(im_file, 'rb') as f:
  594. im_read = f.read()
  595. data = np.frombuffer(im_read, dtype='uint8')
  596. im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
  597. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  598. else:
  599. im = im_file
  600. im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
  601. im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
  602. return im, im_info
  603. def preprocess(im, preprocess_ops):
  604. # process image by preprocess_ops
  605. im_info = {
  606. 'scale_factor': np.array(
  607. [1., 1.], dtype=np.float32),
  608. 'im_shape': None,
  609. }
  610. im, im_info = decode_image(im, im_info)
  611. for operator in preprocess_ops:
  612. im, im_info = operator(im, im_info)
  613. return im, im_info
  614. def nms(bboxes, scores, iou_thresh):
  615. import numpy as np
  616. x1 = bboxes[:, 0]
  617. y1 = bboxes[:, 1]
  618. x2 = bboxes[:, 2]
  619. y2 = bboxes[:, 3]
  620. areas = (y2 - y1) * (x2 - x1)
  621. indices = []
  622. index = scores.argsort()[::-1]
  623. while index.size > 0:
  624. i = index[0]
  625. indices.append(i)
  626. x11 = np.maximum(x1[i], x1[index[1:]])
  627. y11 = np.maximum(y1[i], y1[index[1:]])
  628. x22 = np.minimum(x2[i], x2[index[1:]])
  629. y22 = np.minimum(y2[i], y2[index[1:]])
  630. w = np.maximum(0, x22 - x11 + 1)
  631. h = np.maximum(0, y22 - y11 + 1)
  632. overlaps = w * h
  633. ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
  634. idx = np.where(ious <= iou_thresh)[0]
  635. index = index[idx + 1]
  636. return indices