You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

video_visualizer.py 8.9 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import numpy as np
  3. import pycocotools.mask as mask_util
  4. from detectron2.utils.visualizer import (
  5. ColorMode,
  6. Visualizer,
  7. _create_text_labels,
  8. _PanopticPrediction,
  9. )
  10. from .colormap import random_color
  11. class _DetectedInstance:
  12. """
  13. Used to store data about detected objects in video frame,
  14. in order to transfer color to objects in the future frames.
  15. Attributes:
  16. label (int):
  17. bbox (tuple[float]):
  18. mask_rle (dict):
  19. color (tuple[float]): RGB colors in range (0, 1)
  20. ttl (int): time-to-live for the instance. For example, if ttl=2,
  21. the instance color can be transferred to objects in the next two frames.
  22. """
  23. __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]
  24. def __init__(self, label, bbox, mask_rle, color, ttl):
  25. self.label = label
  26. self.bbox = bbox
  27. self.mask_rle = mask_rle
  28. self.color = color
  29. self.ttl = ttl
  30. class VideoVisualizer:
  31. def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
  32. """
  33. Args:
  34. metadata (MetadataCatalog): image metadata.
  35. """
  36. self.metadata = metadata
  37. self._old_instances = []
  38. assert instance_mode in [
  39. ColorMode.IMAGE,
  40. ColorMode.IMAGE_BW,
  41. ], "Other mode not supported yet."
  42. self._instance_mode = instance_mode
  43. def draw_instance_predictions(self, frame, predictions):
  44. """
  45. Draw instance-level prediction results on an image.
  46. Args:
  47. frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
  48. predictions (Instances): the output of an instance detection/segmentation
  49. model. Following fields will be used to draw:
  50. "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
  51. Returns:
  52. output (VisImage): image object with visualizations.
  53. """
  54. frame_visualizer = Visualizer(frame, self.metadata)
  55. num_instances = len(predictions)
  56. if num_instances == 0:
  57. return frame_visualizer.output
  58. boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
  59. scores = predictions.scores if predictions.has("scores") else None
  60. classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
  61. keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
  62. if predictions.has("pred_masks"):
  63. masks = predictions.pred_masks
  64. # mask IOU is not yet enabled
  65. # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F"))
  66. # assert len(masks_rles) == num_instances
  67. else:
  68. masks = None
  69. detected = [
  70. _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=None, ttl=8)
  71. for i in range(num_instances)
  72. ]
  73. colors = self._assign_colors(detected)
  74. labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
  75. if self._instance_mode == ColorMode.IMAGE_BW:
  76. # any() returns uint8 tensor
  77. frame_visualizer.output.img = frame_visualizer._create_grayscale_image(
  78. (masks.any(dim=0) > 0).numpy() if masks is not None else None
  79. )
  80. alpha = 0.3
  81. else:
  82. alpha = 0.5
  83. frame_visualizer.overlay_instances(
  84. boxes=None if masks is not None else boxes, # boxes are a bit distracting
  85. masks=masks,
  86. labels=labels,
  87. keypoints=keypoints,
  88. assigned_colors=colors,
  89. alpha=alpha,
  90. )
  91. return frame_visualizer.output
  92. def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
  93. """
  94. Args:
  95. sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
  96. each value is the integer label.
  97. area_threshold (Optional[int]): only draw segmentations larger than the threshold
  98. """
  99. # don't need to do anything special
  100. frame_visualizer = Visualizer(frame, self.metadata)
  101. frame_visualizer.draw_sem_seg(sem_seg, area_threshold=None)
  102. return frame_visualizer.output
  103. def draw_panoptic_seg_predictions(
  104. self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
  105. ):
  106. frame_visualizer = Visualizer(frame, self.metadata)
  107. pred = _PanopticPrediction(panoptic_seg, segments_info)
  108. if self._instance_mode == ColorMode.IMAGE_BW:
  109. frame_visualizer.output.img = frame_visualizer._create_grayscale_image(
  110. pred.non_empty_mask()
  111. )
  112. # draw mask for all semantic segments first i.e. "stuff"
  113. for mask, sinfo in pred.semantic_masks():
  114. category_idx = sinfo["category_id"]
  115. try:
  116. mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
  117. except AttributeError:
  118. mask_color = None
  119. frame_visualizer.draw_binary_mask(
  120. mask,
  121. color=mask_color,
  122. text=self.metadata.stuff_classes[category_idx],
  123. alpha=alpha,
  124. area_threshold=area_threshold,
  125. )
  126. all_instances = list(pred.instance_masks())
  127. if len(all_instances) == 0:
  128. return frame_visualizer.output
  129. # draw mask for all instances second
  130. masks, sinfo = list(zip(*all_instances))
  131. num_instances = len(masks)
  132. masks_rles = mask_util.encode(
  133. np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
  134. )
  135. assert len(masks_rles) == num_instances
  136. category_ids = [x["category_id"] for x in sinfo]
  137. detected = [
  138. _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
  139. for i in range(num_instances)
  140. ]
  141. colors = self._assign_colors(detected)
  142. labels = [self.metadata.thing_classes[k] for k in category_ids]
  143. frame_visualizer.overlay_instances(
  144. boxes=None,
  145. masks=masks,
  146. labels=labels,
  147. keypoints=None,
  148. assigned_colors=colors,
  149. alpha=alpha,
  150. )
  151. return frame_visualizer.output
  152. def _assign_colors(self, instances):
  153. """
  154. Naive tracking heuristics to assign same color to the same instance,
  155. will update the internal state of tracked instances.
  156. Returns:
  157. list[tuple[float]]: list of colors.
  158. """
  159. # Compute iou with either boxes or masks:
  160. is_crowd = np.zeros((len(instances),), dtype=np.bool)
  161. if instances[0].bbox is None:
  162. assert instances[0].mask_rle is not None
  163. # use mask iou only when box iou is None
  164. # because box seems good enough
  165. rles_old = [x.mask_rle for x in self._old_instances]
  166. rles_new = [x.mask_rle for x in instances]
  167. ious = mask_util.iou(rles_old, rles_new, is_crowd)
  168. threshold = 0.5
  169. else:
  170. boxes_old = [x.bbox for x in self._old_instances]
  171. boxes_new = [x.bbox for x in instances]
  172. ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
  173. threshold = 0.6
  174. if len(ious) == 0:
  175. ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")
  176. # Only allow matching instances of the same label:
  177. for old_idx, old in enumerate(self._old_instances):
  178. for new_idx, new in enumerate(instances):
  179. if old.label != new.label:
  180. ious[old_idx, new_idx] = 0
  181. matched_new_per_old = np.asarray(ious).argmax(axis=1)
  182. max_iou_per_old = np.asarray(ious).max(axis=1)
  183. # Try to find match for each old instance:
  184. extra_instances = []
  185. for idx, inst in enumerate(self._old_instances):
  186. if max_iou_per_old[idx] > threshold:
  187. newidx = matched_new_per_old[idx]
  188. if instances[newidx].color is None:
  189. instances[newidx].color = inst.color
  190. continue
  191. # If an old instance does not match any new instances,
  192. # keep it for the next frame in case it is just missed by the detector
  193. inst.ttl -= 1
  194. if inst.ttl > 0:
  195. extra_instances.append(inst)
  196. # Assign random color to newly-detected instances:
  197. for inst in instances:
  198. if inst.color is None:
  199. inst.color = random_color(rgb=True, maximum=1)
  200. self._old_instances = instances[:] + extra_instances
  201. return [d.color for d in instances]

No Description