您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import copy
  2. import numpy as np
  3. import cv2
  4. from shapely.geometry import Polygon
  5. import pyclipper
  6. def build_post_process(config, global_config=None):
  7. support_dict = ['DBPostProcess', 'CTCLabelDecode']
  8. config = copy.deepcopy(config)
  9. module_name = config.pop('name')
  10. if module_name == "None":
  11. return
  12. if global_config is not None:
  13. config.update(global_config)
  14. assert module_name in support_dict, Exception(
  15. 'post process only support {}'.format(support_dict))
  16. module_class = eval(module_name)(**config)
  17. return module_class
  18. class DBPostProcess(object):
  19. """
  20. The post process for Differentiable Binarization (DB).
  21. """
  22. def __init__(self,
  23. thresh=0.3,
  24. box_thresh=0.7,
  25. max_candidates=1000,
  26. unclip_ratio=2.0,
  27. use_dilation=False,
  28. score_mode="fast",
  29. box_type='quad',
  30. **kwargs):
  31. self.thresh = thresh
  32. self.box_thresh = box_thresh
  33. self.max_candidates = max_candidates
  34. self.unclip_ratio = unclip_ratio
  35. self.min_size = 3
  36. self.score_mode = score_mode
  37. self.box_type = box_type
  38. assert score_mode in [
  39. "slow", "fast"
  40. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  41. self.dilation_kernel = None if not use_dilation else np.array(
  42. [[1, 1], [1, 1]])
  43. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  44. '''
  45. _bitmap: single map with shape (1, H, W),
  46. whose values are binarized as {0, 1}
  47. '''
  48. bitmap = _bitmap
  49. height, width = bitmap.shape
  50. boxes = []
  51. scores = []
  52. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  53. cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  54. for contour in contours[:self.max_candidates]:
  55. epsilon = 0.002 * cv2.arcLength(contour, True)
  56. approx = cv2.approxPolyDP(contour, epsilon, True)
  57. points = approx.reshape((-1, 2))
  58. if points.shape[0] < 4:
  59. continue
  60. score = self.box_score_fast(pred, points.reshape(-1, 2))
  61. if self.box_thresh > score:
  62. continue
  63. if points.shape[0] > 2:
  64. box = self.unclip(points, self.unclip_ratio)
  65. if len(box) > 1:
  66. continue
  67. else:
  68. continue
  69. box = box.reshape(-1, 2)
  70. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  71. if sside < self.min_size + 2:
  72. continue
  73. box = np.array(box)
  74. box[:, 0] = np.clip(
  75. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  76. box[:, 1] = np.clip(
  77. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  78. boxes.append(box.tolist())
  79. scores.append(score)
  80. return boxes, scores
  81. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  82. '''
  83. _bitmap: single map with shape (1, H, W),
  84. whose values are binarized as {0, 1}
  85. '''
  86. bitmap = _bitmap
  87. height, width = bitmap.shape
  88. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  89. cv2.CHAIN_APPROX_SIMPLE)
  90. if len(outs) == 3:
  91. img, contours, _ = outs[0], outs[1], outs[2]
  92. elif len(outs) == 2:
  93. contours, _ = outs[0], outs[1]
  94. num_contours = min(len(contours), self.max_candidates)
  95. boxes = []
  96. scores = []
  97. for index in range(num_contours):
  98. contour = contours[index]
  99. points, sside = self.get_mini_boxes(contour)
  100. if sside < self.min_size:
  101. continue
  102. points = np.array(points)
  103. if self.score_mode == "fast":
  104. score = self.box_score_fast(pred, points.reshape(-1, 2))
  105. else:
  106. score = self.box_score_slow(pred, contour)
  107. if self.box_thresh > score:
  108. continue
  109. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  110. box, sside = self.get_mini_boxes(box)
  111. if sside < self.min_size + 2:
  112. continue
  113. box = np.array(box)
  114. box[:, 0] = np.clip(
  115. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  116. box[:, 1] = np.clip(
  117. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  118. boxes.append(box.astype("int32"))
  119. scores.append(score)
  120. return np.array(boxes, dtype="int32"), scores
  121. def unclip(self, box, unclip_ratio):
  122. poly = Polygon(box)
  123. distance = poly.area * unclip_ratio / poly.length
  124. offset = pyclipper.PyclipperOffset()
  125. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  126. expanded = np.array(offset.Execute(distance))
  127. return expanded
  128. def get_mini_boxes(self, contour):
  129. bounding_box = cv2.minAreaRect(contour)
  130. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  131. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  132. if points[1][1] > points[0][1]:
  133. index_1 = 0
  134. index_4 = 1
  135. else:
  136. index_1 = 1
  137. index_4 = 0
  138. if points[3][1] > points[2][1]:
  139. index_2 = 2
  140. index_3 = 3
  141. else:
  142. index_2 = 3
  143. index_3 = 2
  144. box = [
  145. points[index_1], points[index_2], points[index_3], points[index_4]
  146. ]
  147. return box, min(bounding_box[1])
  148. def box_score_fast(self, bitmap, _box):
  149. '''
  150. box_score_fast: use bbox mean score as the mean score
  151. '''
  152. h, w = bitmap.shape[:2]
  153. box = _box.copy()
  154. xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
  155. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
  156. ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
  157. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
  158. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  159. box[:, 0] = box[:, 0] - xmin
  160. box[:, 1] = box[:, 1] - ymin
  161. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
  162. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  163. def box_score_slow(self, bitmap, contour):
  164. '''
  165. box_score_slow: use polyon mean score as the mean score
  166. '''
  167. h, w = bitmap.shape[:2]
  168. contour = contour.copy()
  169. contour = np.reshape(contour, (-1, 2))
  170. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  171. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  172. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  173. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  174. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  175. contour[:, 0] = contour[:, 0] - xmin
  176. contour[:, 1] = contour[:, 1] - ymin
  177. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
  178. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  179. def __call__(self, outs_dict, shape_list):
  180. pred = outs_dict['maps']
  181. if not isinstance(pred, np.ndarray):
  182. pred = pred.numpy()
  183. pred = pred[:, 0, :, :]
  184. segmentation = pred > self.thresh
  185. boxes_batch = []
  186. for batch_index in range(pred.shape[0]):
  187. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  188. if self.dilation_kernel is not None:
  189. mask = cv2.dilate(
  190. np.array(segmentation[batch_index]).astype(np.uint8),
  191. self.dilation_kernel)
  192. else:
  193. mask = segmentation[batch_index]
  194. if self.box_type == 'poly':
  195. boxes, scores = self.polygons_from_bitmap(pred[batch_index],
  196. mask, src_w, src_h)
  197. elif self.box_type == 'quad':
  198. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  199. src_w, src_h)
  200. else:
  201. raise ValueError(
  202. "box_type can only be one of ['quad', 'poly']")
  203. boxes_batch.append({'points': boxes})
  204. return boxes_batch
  205. class BaseRecLabelDecode(object):
  206. """ Convert between text-label and text-index """
  207. def __init__(self, character_dict_path=None, use_space_char=False):
  208. self.beg_str = "sos"
  209. self.end_str = "eos"
  210. self.reverse = False
  211. self.character_str = []
  212. if character_dict_path is None:
  213. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  214. dict_character = list(self.character_str)
  215. else:
  216. with open(character_dict_path, "rb") as fin:
  217. lines = fin.readlines()
  218. for line in lines:
  219. line = line.decode('utf-8').strip("\n").strip("\r\n")
  220. self.character_str.append(line)
  221. if use_space_char:
  222. self.character_str.append(" ")
  223. dict_character = list(self.character_str)
  224. if 'arabic' in character_dict_path:
  225. self.reverse = True
  226. dict_character = self.add_special_char(dict_character)
  227. self.dict = {}
  228. for i, char in enumerate(dict_character):
  229. self.dict[char] = i
  230. self.character = dict_character
  231. def pred_reverse(self, pred):
  232. pred_re = []
  233. c_current = ''
  234. for c in pred:
  235. if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
  236. if c_current != '':
  237. pred_re.append(c_current)
  238. pred_re.append(c)
  239. c_current = ''
  240. else:
  241. c_current += c
  242. if c_current != '':
  243. pred_re.append(c_current)
  244. return ''.join(pred_re[::-1])
  245. def add_special_char(self, dict_character):
  246. return dict_character
  247. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  248. """ convert text-index into text-label. """
  249. result_list = []
  250. ignored_tokens = self.get_ignored_tokens()
  251. batch_size = len(text_index)
  252. for batch_idx in range(batch_size):
  253. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  254. if is_remove_duplicate:
  255. selection[1:] = text_index[batch_idx][1:] != text_index[
  256. batch_idx][:-1]
  257. for ignored_token in ignored_tokens:
  258. selection &= text_index[batch_idx] != ignored_token
  259. char_list = [
  260. self.character[text_id]
  261. for text_id in text_index[batch_idx][selection]
  262. ]
  263. if text_prob is not None:
  264. conf_list = text_prob[batch_idx][selection]
  265. else:
  266. conf_list = [1] * len(selection)
  267. if len(conf_list) == 0:
  268. conf_list = [0]
  269. text = ''.join(char_list)
  270. if self.reverse: # for arabic rec
  271. text = self.pred_reverse(text)
  272. result_list.append((text, np.mean(conf_list).tolist()))
  273. return result_list
  274. def get_ignored_tokens(self):
  275. return [0] # for ctc blank
  276. class CTCLabelDecode(BaseRecLabelDecode):
  277. """ Convert between text-label and text-index """
  278. def __init__(self, character_dict_path=None, use_space_char=False,
  279. **kwargs):
  280. super(CTCLabelDecode, self).__init__(character_dict_path,
  281. use_space_char)
  282. def __call__(self, preds, label=None, *args, **kwargs):
  283. if isinstance(preds, tuple) or isinstance(preds, list):
  284. preds = preds[-1]
  285. if not isinstance(preds, np.ndarray):
  286. preds = preds.numpy()
  287. preds_idx = preds.argmax(axis=2)
  288. preds_prob = preds.max(axis=2)
  289. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  290. if label is None:
  291. return text
  292. label = self.decode(label)
  293. return text, label
  294. def add_special_char(self, dict_character):
  295. dict_character = ['blank'] + dict_character
  296. return dict_character