You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

browse_dataset.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Visualization for detection/segmentation dataset.
  16. """
  17. import os
  18. import sys
  19. import importlib
  20. import numpy as np
  21. from mindspore import log as logger
  22. def imshow_det_bbox(image,
  23. bboxes,
  24. labels,
  25. segm=None,
  26. class_names=None,
  27. score_threshold=0,
  28. bbox_color=(0, 255, 0),
  29. text_color=(203, 192, 255),
  30. mask_color=(128, 0, 128),
  31. thickness=2,
  32. font_size=0.8,
  33. show=True,
  34. win_name="win",
  35. wait_time=2000,
  36. out_file=None
  37. ):
  38. """Draw an image with given bboxes and class labels (with scores).
  39. Args:
  40. image (ndarray): The image to be displayed, shaped (C, H, W) or (H, W, C), formatted RGB.
  41. bboxes (ndarray): Bounding boxes (with scores), shaped (N, 4) or (N, 5),
  42. data should be ordered with (N, x, y, w, h).
  43. labels (ndarray): Labels of bboxes, shaped (N, 1).
  44. segm (ndarray): The segmentation masks of image in M classes, shaped (M, H, W) (Default=None).
  45. class_names (list[str], tuple[str], dict): Names of each class to map label to class name
  46. (Default=None, only display label).
  47. score_threshold (float): Minimum score of bboxes to be shown (Default=0).
  48. bbox_color (tuple(int)): Color of bbox lines.
  49. The tuple of color should be in BGR order (Default=(0, 255 ,0), means 'green').
  50. text_color (tuple(int)): Color of texts.
  51. The tuple of color should be in BGR order (Default=(203, 192, 255), means 'pink').
  52. mask_color (tuple(int)): Color of mask.
  53. The tuple of color should be in BGR order (Default=(128, 0, 128), means 'purple').
  54. thickness (int): Thickness of lines (Default=2).
  55. font_size (int, float): Font size of texts (Default=0.8).
  56. show (bool): Whether to show the image (Default=True).
  57. win_name (str): The window name (Default="win").
  58. wait_time (int): Value of waitKey param (Default=2000, means display interval is 2000ms).
  59. out_file (str, optional): The filename to write the imagee (Default=None). File extension name
  60. is required to indicate the image compression type, e.g. 'jpg', 'png'.
  61. Returns:
  62. ndarray: The image with bboxes drawn on it.
  63. """
  64. try:
  65. cv2 = importlib.import_module("cv2")
  66. except ModuleNotFoundError:
  67. raise ImportError("import cv2 failed, seems you have to run `pip install opencv-python`.")
  68. # validation
  69. assert isinstance(image, np.ndarray) and image.ndim == 3 and (image.shape[0] == 3 or image.shape[2] == 3),\
  70. "image must be a ndarray in (H, W, C) or (C, H, W) format."
  71. if bboxes is not None:
  72. assert isinstance(bboxes, np.ndarray) and bboxes.ndim == 2 and (bboxes.shape[1] == 4 or bboxes.shape[1] == 5), \
  73. "bboxes must be a ndarray in (N, 4) or (N, 5) format."
  74. assert isinstance(labels, np.ndarray) and labels.ndim == 2 and labels.shape[1] == 1 and \
  75. labels.shape[0] == bboxes.shape[0], "labels must be a ndarray in (N, 1) format and has same N with bboxes."
  76. if segm is not None:
  77. assert isinstance(segm, np.ndarray) and segm.ndim == 3, "segm must be a ndarray in (M, H, W) format."
  78. H, W = (image.shape[0], image.shape[1]) if image.shape[2] == 3 else (image.shape[1], image.shape[2])
  79. assert H == segm.shape[1] and W == segm.shape[2], "segm must has same height and width with image."
  80. if bboxes is not None:
  81. assert bboxes.shape[0] <= segm.shape[0], "number of segm masks must not be less than the number of bboxes."
  82. assert isinstance(class_names, (tuple, list, dict)), "class_names must be a list, tuple or dict."
  83. assert isinstance(bbox_color, tuple) and len(bbox_color) == 3, \
  84. "bbox_color must be a three tuple, formatted (B, G, R)."
  85. assert isinstance(text_color, tuple) and len(text_color) == 3, \
  86. "text_color must be a three tuple, formatted (B, G, R)."
  87. assert isinstance(mask_color, tuple) and len(mask_color) == 3, \
  88. "mask_color must be a three tuple, formatted (B, G, R)."
  89. assert isinstance(thickness, int), "thickness must be a int."
  90. assert thickness >= 0, "thickness must be larger than or equal to zero."
  91. assert isinstance(font_size, (int, float)), "font_size must be a int or float."
  92. assert font_size >= 0, "font_size must be larger than or equal to zero."
  93. assert isinstance(show, bool), "show must be a bool."
  94. assert isinstance(win_name, str), "win_name must be a str."
  95. assert isinstance(wait_time, int), "wait_time must be a int."
  96. assert wait_time >= 0, "wait_time must be larger than or equal to zero."
  97. if out_file is not None:
  98. assert isinstance(out_file, str), "out_file must be a str."
  99. if score_threshold > 0:
  100. assert bboxes.shape[1] == 5
  101. if not show:
  102. assert out_file is not None
  103. # image
  104. if image.shape[0] == 3:
  105. image = image.transpose((1, 2, 0))
  106. draw_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  107. if bboxes is not None:
  108. bbox_num = bboxes.shape[0]
  109. for i in range(bbox_num):
  110. draw_bbox = bboxes[i]
  111. if len(draw_bbox) > 4:
  112. if draw_bbox[4] < score_threshold:
  113. continue
  114. # bbox
  115. x1, y1 = int(draw_bbox[0]), int(draw_bbox[1])
  116. x2, y2 = int(draw_bbox[0]+draw_bbox[2]), int(draw_bbox[1]+draw_bbox[3])
  117. cv2.rectangle(draw_image, (x1, y1), (x2, y2), bbox_color, thickness)
  118. # label
  119. try:
  120. draw_label = str(class_names[labels[i][0]]) if class_names is not None else f'class {labels[i][0]}'
  121. except (IndexError, KeyError):
  122. draw_label = f'class {labels[i][0]}'
  123. if len(draw_bbox) > 4:
  124. draw_label += f'|{draw_bbox[-1]:.02f}'
  125. cv2.putText(draw_image, draw_label, (x1, y2), cv2.FONT_HERSHEY_SIMPLEX, font_size, text_color, thickness)
  126. if segm is not None:
  127. mask = segm[i].astype(bool)
  128. draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
  129. else:
  130. if segm is not None:
  131. segm_num = segm.shape[0]
  132. for i in range(segm_num):
  133. mask = segm[i].astype(bool)
  134. draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
  135. if show:
  136. cv2.imshow(win_name, draw_image)
  137. if cv2.waitKey(wait_time) == 27:
  138. sys.exit()
  139. if out_file:
  140. logger.info("Saving image file with name: " + out_file + "...")
  141. cv2.imwrite(out_file, draw_image)
  142. os.chmod(out_file, 0o600)
  143. return draw_image