You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.py 7.1 kB

2 years ago
Line numbers 1–208
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from functools import partial
  3. import numpy as np
  4. import torch
  5. from six.moves import map, zip
  6. from ..mask.structures import BitmapMasks, PolygonMasks
  7. def multi_apply(func, *args, **kwargs):
  8. """Apply function to a list of arguments.
  9. Note:
  10. This function applies the ``func`` to multiple inputs and
  11. map the multiple outputs of the ``func`` into different
  12. list. Each list contains the same type of outputs corresponding
  13. to different inputs.
  14. Args:
  15. func (Function): A function that will be applied to a list of
  16. arguments
  17. Returns:
  18. tuple(list): A tuple containing multiple list, each list contains \
  19. a kind of returned results by the function
  20. """
  21. pfunc = partial(func, **kwargs) if kwargs else func
  22. map_results = map(pfunc, *args)
  23. return tuple(map(list, zip(*map_results)))
  24. def unmap(data, count, inds, fill=0):
  25. """Unmap a subset of item (data) back to the original set of items (of size
  26. count)"""
  27. if data.dim() == 1:
  28. ret = data.new_full((count, ), fill)
  29. ret[inds.type(torch.bool)] = data
  30. else:
  31. new_size = (count, ) + data.size()[1:]
  32. ret = data.new_full(new_size, fill)
  33. ret[inds.type(torch.bool), :] = data
  34. return ret
  35. def mask2ndarray(mask):
  36. """Convert Mask to ndarray..
  37. Args:
  38. mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or
  39. torch.Tensor or np.ndarray): The mask to be converted.
  40. Returns:
  41. np.ndarray: Ndarray mask of shape (n, h, w) that has been converted
  42. """
  43. if isinstance(mask, (BitmapMasks, PolygonMasks)):
  44. mask = mask.to_ndarray()
  45. elif isinstance(mask, torch.Tensor):
  46. mask = mask.detach().cpu().numpy()
  47. elif not isinstance(mask, np.ndarray):
  48. raise TypeError(f'Unsupported {type(mask)} data type')
  49. return mask
  50. def flip_tensor(src_tensor, flip_direction):
  51. """flip tensor base on flip_direction.
  52. Args:
  53. src_tensor (Tensor): input feature map, shape (B, C, H, W).
  54. flip_direction (str): The flipping direction. Options are
  55. 'horizontal', 'vertical', 'diagonal'.
  56. Returns:
  57. out_tensor (Tensor): Flipped tensor.
  58. """
  59. assert src_tensor.ndim == 4
  60. valid_directions = ['horizontal', 'vertical', 'diagonal']
  61. assert flip_direction in valid_directions
  62. if flip_direction == 'horizontal':
  63. out_tensor = torch.flip(src_tensor, [3])
  64. elif flip_direction == 'vertical':
  65. out_tensor = torch.flip(src_tensor, [2])
  66. else:
  67. out_tensor = torch.flip(src_tensor, [2, 3])
  68. return out_tensor
  69. def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
  70. """Extract a multi-scale single image tensor from a multi-scale batch
  71. tensor based on batch index.
  72. Note: The default value of detach is True, because the proposal gradient
  73. needs to be detached during the training of the two-stage model. E.g
  74. Cascade Mask R-CNN.
  75. Args:
  76. mlvl_tensors (list[Tensor]): Batch tensor for all scale levels,
  77. each is a 4D-tensor.
  78. batch_id (int): Batch index.
  79. detach (bool): Whether detach gradient. Default True.
  80. Returns:
  81. list[Tensor]: Multi-scale single image tensor.
  82. """
  83. assert isinstance(mlvl_tensors, (list, tuple))
  84. num_levels = len(mlvl_tensors)
  85. if detach:
  86. mlvl_tensor_list = [
  87. mlvl_tensors[i][batch_id].detach() for i in range(num_levels)
  88. ]
  89. else:
  90. mlvl_tensor_list = [
  91. mlvl_tensors[i][batch_id] for i in range(num_levels)
  92. ]
  93. return mlvl_tensor_list
  94. def filter_scores_and_topk(scores, score_thr, topk, results=None):
  95. """Filter results using score threshold and topk candidates.
  96. Args:
  97. scores (Tensor): The scores, shape (num_bboxes, K).
  98. score_thr (float): The score filter threshold.
  99. topk (int): The number of topk candidates.
  100. results (dict or list or Tensor, Optional): The results to
  101. which the filtering rule is to be applied. The shape
  102. of each item is (num_bboxes, N).
  103. Returns:
  104. tuple: Filtered results
  105. - scores (Tensor): The scores after being filtered, \
  106. shape (num_bboxes_filtered, ).
  107. - labels (Tensor): The class labels, shape \
  108. (num_bboxes_filtered, ).
  109. - anchor_idxs (Tensor): The anchor indexes, shape \
  110. (num_bboxes_filtered, ).
  111. - filtered_results (dict or list or Tensor, Optional): \
  112. The filtered results. The shape of each item is \
  113. (num_bboxes_filtered, N).
  114. """
  115. valid_mask = scores > score_thr
  116. scores = scores[valid_mask]
  117. valid_idxs = torch.nonzero(valid_mask)
  118. num_topk = min(topk, valid_idxs.size(0))
  119. # torch.sort is actually faster than .topk (at least on GPUs)
  120. scores, idxs = scores.sort(descending=True)
  121. scores = scores[:num_topk]
  122. topk_idxs = valid_idxs[idxs[:num_topk]]
  123. keep_idxs, labels = topk_idxs.unbind(dim=1)
  124. filtered_results = None
  125. if results is not None:
  126. if isinstance(results, dict):
  127. filtered_results = {k: v[keep_idxs] for k, v in results.items()}
  128. elif isinstance(results, list):
  129. filtered_results = [result[keep_idxs] for result in results]
  130. elif isinstance(results, torch.Tensor):
  131. filtered_results = results[keep_idxs]
  132. else:
  133. raise NotImplementedError(f'Only supports dict or list or Tensor, '
  134. f'but get {type(results)}.')
  135. return scores, labels, keep_idxs, filtered_results
  136. def center_of_mass(mask, esp=1e-6):
  137. """Calculate the centroid coordinates of the mask.
  138. Args:
  139. mask (Tensor): The mask to be calculated, shape (h, w).
  140. esp (float): Avoid dividing by zero. Default: 1e-6.
  141. Returns:
  142. tuple[Tensor]: the coordinates of the center point of the mask.
  143. - center_h (Tensor): the center point of the height.
  144. - center_w (Tensor): the center point of the width.
  145. """
  146. h, w = mask.shape
  147. grid_h = torch.arange(h, device=mask.device)[:, None]
  148. grid_w = torch.arange(w, device=mask.device)
  149. normalizer = mask.sum().float().clamp(min=esp)
  150. center_h = (mask * grid_h).sum() / normalizer
  151. center_w = (mask * grid_w).sum() / normalizer
  152. return center_h, center_w
  153. def generate_coordinate(featmap_sizes, device='cuda'):
  154. """Generate the coordinate.
  155. Args:
  156. featmap_sizes (tuple): The feature to be calculated,
  157. of shape (N, C, W, H).
  158. device (str): The device where the feature will be put on.
  159. Returns:
  160. coord_feat (Tensor): The coordinate feature, of shape (N, 2, W, H).
  161. """
  162. x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
  163. y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
  164. y, x = torch.meshgrid(y_range, x_range)
  165. y = y.expand([featmap_sizes[0], 1, -1, -1])
  166. x = x.expand([featmap_sizes[0], 1, -1, -1])
  167. coord_feat = torch.cat([x, y], 1)
  168. return coord_feat

No Description

Contributors (2)