You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

image.py 11 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import matplotlib.pyplot as plt
  3. import mmcv
  4. import numpy as np
  5. import pycocotools.mask as mask_util
  6. from matplotlib.collections import PatchCollection
  7. from matplotlib.patches import Polygon
  8. from ..utils import mask2ndarray
  # Small padding (in pixels) added to the image width/height before they are
  # divided by the DPI to size the matplotlib figure, so that matplotlib's
  # truncation does not drop a pixel
  # (https://github.com/matplotlib/matplotlib/issues/15363).
  EPS = 1e-2
  def color_val_matplotlib(color):
      """Turn a BGR color specification into a normalized RGB tuple.

      matplotlib wants RGB floats, while the rest of this module works with
      BGR colors in the ``mmcv.color_val`` conventions.

      Args:
          color (:obj:`Color`/str/tuple/int/ndarray): Any color input accepted
              by ``mmcv.color_val``; tuples/arrays are interpreted as BGR.

      Returns:
          tuple[float]: The R, G and B channels, each divided by 255.
      """
      bgr = mmcv.color_val(color)
      # Walk the channels backwards (BGR -> RGB) while normalizing by 255.
      return tuple(channel / 255 for channel in bgr[::-1])
  def _get_mask_colors(labels, mask_color):
      """Build the per-class list of mask colors, indexed by label id.

      Args:
          labels (ndarray): Non-empty integer array of class labels.
          mask_color (str or tuple(int) or :obj:`Color` or None): ``None``
              selects fixed-seed random colors (one per class); anything else
              is a BGR color spec used for every class.

      Returns:
          list[ndarray]: ``labels.max() + 1`` uint8 colors in RGB order
              (random colors are arbitrary, so their order does not matter).
      """
      # Python int so list repetition below never goes through numpy
      # scalar arithmetic.
      num_colors = int(labels.max()) + 1
      if mask_color is None:
          # Get random state before set seed, and restore random state later.
          # Prevent loss of randomness.
          # See: https://github.com/open-mmlab/mmdetection/issues/5844
          state = np.random.get_state()
          np.random.seed(42)
          colors = [
              np.random.randint(0, 256, (1, 3), dtype=np.uint8)
              for _ in range(num_colors)
          ]
          np.random.set_state(state)
          return colors
      # Reverse BGR -> RGB because masks are blended into the image after it
      # has been converted to RGB.
      rgb = np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
      return [rgb] * num_colors


  def imshow_det_bboxes(img,
                        bboxes,
                        labels,
                        segms=None,
                        class_names=None,
                        score_thr=0,
                        bbox_color='green',
                        text_color='green',
                        mask_color=None,
                        thickness=2,
                        font_size=13,
                        win_name='',
                        show=True,
                        wait_time=0,
                        out_file=None):
      """Draw bboxes, class labels (with scores) and masks on an image.

      Args:
          img (str or ndarray): The image to be displayed, as a path or a BGR
              array readable by ``mmcv.imread``.
          bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
              (n, 5); the 5th column, when present, is the score.
          labels (ndarray): Labels of bboxes, shaped (n,).
          segms (ndarray or None): Masks, shaped (n, h, w), or None.
          class_names (list[str]): Names of each class. When None, labels are
              rendered as ``class {label}``. Default: None
          score_thr (float): Minimum score of bboxes to be shown. A value > 0
              requires (n, 5) bboxes. Default: 0
          bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox lines.
              The tuple of color should be in BGR order. Default: 'green'
          text_color (str or tuple(int) or :obj:`Color`): Color of texts.
              The tuple of color should be in BGR order. Default: 'green'
          mask_color (str or tuple(int) or :obj:`Color`, optional):
              Color of masks. The tuple of color should be in BGR order.
              Default: None, meaning a fixed-seed random color per class.
          thickness (int): Thickness of lines. Default: 2
          font_size (int): Font size of texts. Default: 13
          win_name (str): The window name. Default: ''
          show (bool): Whether to show the image. Default: True
          wait_time (float): Seconds to pause while showing; 0 blocks until
              the window is closed. Default: 0
          out_file (str, optional): The filename to write the image.
              Default: None

      Returns:
          ndarray: The rendered image (BGR). Its size is the size of the
              rendered canvas, which normally equals the input image size.
      """
      assert bboxes.ndim == 2, \
          f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
      assert labels.ndim == 1, \
          f' labels ndim should be 1, but its ndim is {labels.ndim}.'
      assert bboxes.shape[0] == labels.shape[0], \
          'bboxes.shape[0] and labels.shape[0] should have the same length.'
      assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
          f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
      img = mmcv.imread(img).astype(np.uint8)

      if score_thr > 0:
          assert bboxes.shape[1] == 5
          scores = bboxes[:, -1]
          inds = scores > score_thr
          bboxes = bboxes[inds, :]
          labels = labels[inds]
          if segms is not None:
              segms = segms[inds, ...]

      mask_colors = []
      if labels.shape[0] > 0:
          mask_colors = _get_mask_colors(labels, mask_color)

      bbox_color = color_val_matplotlib(bbox_color)
      text_color = color_val_matplotlib(text_color)

      img = mmcv.bgr2rgb(img)
      width, height = img.shape[1], img.shape[0]
      img = np.ascontiguousarray(img)

      fig = plt.figure(win_name, frameon=False)
      try:
          plt.title(win_name)
          canvas = fig.canvas
          dpi = fig.get_dpi()
          # add a small EPS to avoid precision lost due to matplotlib's
          # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
          fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
          # remove white edges by setting all subplot margins to zero
          plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
          ax = plt.gca()
          ax.axis('off')

          polygons = []
          edge_colors = []
          for i, (bbox, label) in enumerate(zip(bboxes, labels)):
              bbox_int = bbox.astype(np.int32)
              poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
                      [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
              polygons.append(Polygon(np.array(poly).reshape((4, 2))))
              edge_colors.append(bbox_color)

              label_text = class_names[
                  label] if class_names is not None else f'class {label}'
              if len(bbox) > 4:
                  label_text += f'|{bbox[-1]:.02f}'
              ax.text(
                  bbox_int[0],
                  bbox_int[1],
                  f'{label_text}',
                  bbox={
                      'facecolor': 'black',
                      'alpha': 0.8,
                      'pad': 0.7,
                      'edgecolor': 'none'
                  },
                  color=text_color,
                  fontsize=font_size,
                  verticalalignment='top',
                  horizontalalignment='left')

              if segms is not None:
                  color_mask = mask_colors[label]
                  mask = segms[i].astype(bool)
                  # 50/50 blend of the mask color into the (RGB) image pixels
                  img[mask] = img[mask] * 0.5 + color_mask * 0.5

          plt.imshow(img)
          p = PatchCollection(
              polygons,
              facecolor='none',
              edgecolors=edge_colors,
              linewidths=thickness)
          ax.add_collection(p)

          # Reshape with the size the renderer actually produced instead of
          # the input size: the two may differ (e.g. on HiDPI displays), in
          # which case reshaping to (height, width, 4) would fail.
          stream, (buf_width, buf_height) = canvas.print_to_buffer()
          buffer = np.frombuffer(stream, dtype='uint8')
          img_rgba = buffer.reshape(buf_height, buf_width, 4)
          rgb, _ = np.split(img_rgba, [3], axis=2)
          img = mmcv.rgb2bgr(rgb.astype('uint8'))

          if show:
              # We do not use cv2 for display because in some cases, opencv
              # will conflict with Qt, it will output a warning: Current
              # thread is not the object's thread. You can refer to
              # https://github.com/opencv/opencv-python/issues/46 for details
              if wait_time == 0:
                  plt.show()
              else:
                  plt.show(block=False)
                  plt.pause(wait_time)
          if out_file is not None:
              mmcv.imwrite(img, out_file)
      finally:
          # Always release this figure, even when drawing fails, so repeated
          # calls do not accumulate open matplotlib figures.
          plt.close(fig)
      return img
  def imshow_gt_det_bboxes(img,
                           annotation,
                           result,
                           class_names=None,
                           score_thr=0,
                           gt_bbox_color=(255, 102, 61),
                           gt_text_color=(255, 102, 61),
                           gt_mask_color=(255, 102, 61),
                           det_bbox_color=(72, 101, 241),
                           det_text_color=(72, 101, 241),
                           det_mask_color=(72, 101, 241),
                           thickness=2,
                           font_size=13,
                           win_name='',
                           show=True,
                           wait_time=0,
                           out_file=None):
      """Overlay ground truth and detection results on a single image.

      The ground truth is drawn first (never shown or saved on its own), then
      the detections are drawn on top of that rendering. Only the second pass
      honours ``score_thr``, ``show``, ``wait_time`` and ``out_file``.

      Args:
          img (str or ndarray): The image to be displayed.
          annotation (dict): Ground-truth annotations. Must contain the keys
              'gt_bboxes' and 'gt_labels'; 'gt_masks' is optional.
          result (tuple[list] or list): The detection result, either
              (bbox, segm) or just bbox. For ms rcnn, segm is itself a tuple
              whose first element is used.
          class_names (list[str]): Names of each class.
          score_thr (float): Minimum score of detected bboxes to be shown.
              Default: 0
          gt_bbox_color (str or tuple(int) or :obj:`Color`): Color of GT bbox
              lines, BGR order for tuples. Default: (255, 102, 61)
          gt_text_color (str or tuple(int) or :obj:`Color`): Color of GT
              texts, BGR order for tuples. Default: (255, 102, 61)
          gt_mask_color (str or tuple(int) or :obj:`Color`, optional): Color
              of GT masks, BGR order for tuples. Default: (255, 102, 61)
          det_bbox_color (str or tuple(int) or :obj:`Color`): Color of
              detection bbox lines, BGR order for tuples.
              Default: (72, 101, 241)
          det_text_color (str or tuple(int) or :obj:`Color`): Color of
              detection texts, BGR order for tuples. Default: (72, 101, 241)
          det_mask_color (str or tuple(int) or :obj:`Color`, optional): Color
              of detection masks, BGR order for tuples.
              Default: (72, 101, 241)
          thickness (int): Thickness of lines. Default: 2
          font_size (int): Font size of texts. Default: 13
          win_name (str): The window name. Default: ''
          show (bool): Whether to show the image. Default: True
          wait_time (float): Value of waitKey param. Default: 0.
          out_file (str, optional): The filename to write the image.
              Default: None

      Returns:
          ndarray: The image with bboxes or masks drawn on it.
      """
      assert 'gt_bboxes' in annotation
      assert 'gt_labels' in annotation
      assert isinstance(result, (tuple, list)), \
          f'Expected tuple or list, but get {type(result)}'

      gt_masks = annotation.get('gt_masks', None)
      if gt_masks is not None:
          gt_masks = mask2ndarray(gt_masks)

      # Pass 1: ground truth only; the rendering is fed into pass 2.
      gt_rendered = imshow_det_bboxes(
          mmcv.imread(img),
          annotation['gt_bboxes'],
          annotation['gt_labels'],
          gt_masks,
          class_names=class_names,
          bbox_color=gt_bbox_color,
          text_color=gt_text_color,
          mask_color=gt_mask_color,
          thickness=thickness,
          font_size=font_size,
          win_name=win_name,
          show=False)

      bbox_result, segm_result = (result if isinstance(result, tuple) else
                                  (result, None))
      if isinstance(segm_result, tuple):
          segm_result = segm_result[0]  # ms rcnn

      # Stack the per-class box arrays and tag every row with its class index.
      bboxes = np.vstack(bbox_result)
      labels = np.concatenate([
          np.full(cls_bboxes.shape[0], cls_idx, dtype=np.int32)
          for cls_idx, cls_bboxes in enumerate(bbox_result)
      ])

      segms = None
      if segm_result is not None and len(labels) > 0:  # non empty
          # Flatten the per-class mask lists, decode them with pycocotools
          # (presumably RLE-encoded), then move the instance axis first.
          decoded = mask_util.decode(mmcv.concat_list(segm_result))
          segms = decoded.transpose(2, 0, 1)

      # Pass 2: detections on top of the ground-truth rendering.
      return imshow_det_bboxes(
          gt_rendered,
          bboxes,
          labels,
          segms=segms,
          class_names=class_names,
          score_thr=score_thr,
          bbox_color=det_bbox_color,
          text_color=det_text_color,
          mask_color=det_mask_color,
          thickness=thickness,
          font_size=font_size,
          win_name=win_name,
          show=show,
          wait_time=wait_time,
          out_file=out_file)

No Description

Contributors (2)