transforms.py

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch


def bbox_flip(bboxes, img_shape, direction='horizontal'):
    """Flip bboxes horizontally, vertically or diagonally.

    Args:
        bboxes (Tensor): Shape (..., 4*k)
        img_shape (tuple): Image shape.
        direction (str): Flip direction, options are "horizontal",
            "vertical", "diagonal". Default: "horizontal"

    Returns:
        Tensor: Flipped bboxes.
    """
    assert bboxes.shape[-1] % 4 == 0
    assert direction in ['horizontal', 'vertical', 'diagonal']
    flipped = bboxes.clone()
    if direction == 'horizontal':
        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
    elif direction == 'vertical':
        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
    else:
        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
    return flipped
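

# Usage sketch (illustrative values): horizontally flipping a single
# (x1, y1, x2, y2) box in an image with shape (H, W) = (100, 200).
#   >>> boxes = torch.tensor([[10., 20., 30., 40.]])
#   >>> bbox_flip(boxes, (100, 200), direction='horizontal')
#   tensor([[170.,  20., 190.,  40.]])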


def bbox_mapping(bboxes,
                 img_shape,
                 scale_factor,
                 flip,
                 flip_direction='horizontal'):
    """Map bboxes from the original image scale to testing scale."""
    new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
    if flip:
        new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
    return new_bboxes


def bbox_mapping_back(bboxes,
                      img_shape,
                      scale_factor,
                      flip,
                      flip_direction='horizontal'):
    """Map bboxes from testing scale to original image scale."""
    new_bboxes = bbox_flip(bboxes, img_shape,
                           flip_direction) if flip else bboxes
    new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
    return new_bboxes.view(bboxes.shape)
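

# Usage sketch (illustrative values): mapping a box to a 2x test scale with a
# horizontal flip, then back again; bbox_mapping_back inverts bbox_mapping.
#   >>> boxes = torch.tensor([[10., 20., 30., 40.]])
#   >>> mapped = bbox_mapping(boxes, (200, 400), 2.0, True)
#   >>> bbox_mapping_back(mapped, (200, 400), 2.0, True)
#   tensor([[10., 20., 30., 40.]])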


def bbox2roi(bbox_list):
    """Convert a list of bboxes to roi format.

    Args:
        bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
            of images.

    Returns:
        Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
    """
    rois_list = []
    for img_id, bboxes in enumerate(bbox_list):
        if bboxes.size(0) > 0:
            img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
            rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
        else:
            rois = bboxes.new_zeros((0, 5))
        rois_list.append(rois)
    rois = torch.cat(rois_list, 0)
    return rois
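

# Usage sketch (illustrative values): two images, the second with no boxes;
# the first column of the result is the image index within the batch.
#   >>> boxes_img0 = torch.tensor([[10., 20., 30., 40.]])
#   >>> boxes_img1 = torch.zeros((0, 4))
#   >>> bbox2roi([boxes_img0, boxes_img1])
#   tensor([[ 0., 10., 20., 30., 40.]])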


def roi2bbox(rois):
    """Convert rois to bounding box format.

    Args:
        rois (torch.Tensor): RoIs with the shape (n, 5) where the first
            column indicates batch id of each RoI.

    Returns:
        list[torch.Tensor]: Converted boxes of corresponding rois.
    """
    bbox_list = []
    img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
    for img_id in img_ids:
        inds = (rois[:, 0] == img_id.item())
        bbox = rois[inds, 1:]
        bbox_list.append(bbox)
    return bbox_list
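

# Usage sketch (illustrative values): splitting rois back into per-image
# boxes; this inverts bbox2roi for images that contain at least one box.
#   >>> rois = torch.tensor([[0., 10., 20., 30., 40.],
#   ...                      [1., 15., 25., 35., 45.]])
#   >>> [b.tolist() for b in roi2bbox(rois)]
#   [[[10.0, 20.0, 30.0, 40.0]], [[15.0, 25.0, 35.0, 45.0]]]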


def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): bbox results of each class
    """
    if bboxes.shape[0] == 0:
        return [
            np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)
        ]
    else:
        if isinstance(bboxes, torch.Tensor):
            bboxes = bboxes.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
        # discard all-zero rows (e.g. padded detections) before grouping
        # the remaining boxes by class
        bboxes_sum = np.sum(bboxes, 1)
        bboxes_ids = np.where(bboxes_sum > 0)
        bboxes = bboxes[bboxes_ids]
        labels = labels[bboxes_ids]
        return [bboxes[labels == i, :] for i in range(num_classes)]
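

# Usage sketch (illustrative values): grouping three detections from a
# 3-class model into per-class arrays; column 4 holds the score.
#   >>> dets = torch.tensor([[10., 10., 20., 20., 0.9],
#   ...                      [30., 30., 40., 40., 0.8],
#   ...                      [50., 50., 60., 60., 0.7]])
#   >>> labels = torch.tensor([0, 2, 0])
#   >>> [r.shape for r in bbox2result(dets, labels, 3)]
#   [(2, 5), (0, 5), (1, 5)]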


def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (B, N, 2) or (N, 2).
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.

    Returns:
        Tensor: Boxes with shape (N, 4) or (B, N, 4)
    """
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    bboxes = torch.stack([x1, y1, x2, y2], -1)
    if max_shape is not None:
        if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
            # speed up
            bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
            bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
            return bboxes

        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            from mmdet.core.export import dynamic_clip_for_onnx
            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
            bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)
        min_xy = x1.new_tensor(0)
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
    return bboxes
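

# Usage sketch (illustrative values): decoding one point with (left, top,
# right, bottom) distances, clipped to an image of shape (H, W) = (20, 20).
#   >>> pts = torch.tensor([[15., 15.]])
#   >>> dist = torch.tensor([[5., 5., 10., 10.]])
#   >>> distance2bbox(pts, dist, max_shape=(20, 20))
#   tensor([[10., 10., 20., 20.]])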


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Compute the distances from points to the four boundaries of bboxes.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        bbox (Tensor): Shape (n, 4), "xyxy" format.
        max_dis (float): Upper bound of the distance.
        eps (float): a small value to ensure target < max_dis, instead of <=.

    Returns:
        Tensor: Distances (left, top, right, bottom) with shape (n, 4).
    """
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)
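

# Usage sketch (illustrative values): bbox2distance inverts distance2bbox
# for points that lie inside their boxes.
#   >>> pts = torch.tensor([[15., 15.]])
#   >>> box = torch.tensor([[10., 10., 25., 25.]])
#   >>> bbox2distance(pts, box)
#   tensor([[ 5.,  5., 10., 10.]])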


def bbox_rescale(bboxes, scale_factor=1.0):
    """Rescale bounding box w.r.t. scale_factor.

    Args:
        bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
        scale_factor (float): rescale factor

    Returns:
        Tensor: Rescaled bboxes.
    """
    if bboxes.size(1) == 5:
        bboxes_ = bboxes[:, 1:]
        inds_ = bboxes[:, 0]
    else:
        bboxes_ = bboxes
    cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
    cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
    w = bboxes_[:, 2] - bboxes_[:, 0]
    h = bboxes_[:, 3] - bboxes_[:, 1]
    w = w * scale_factor
    h = h * scale_factor
    x1 = cx - 0.5 * w
    x2 = cx + 0.5 * w
    y1 = cy - 0.5 * h
    y2 = cy + 0.5 * h
    if bboxes.size(1) == 5:
        rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
    else:
        rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
    return rescaled_bboxes
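

# Usage sketch (illustrative values): expanding a box to twice its width and
# height around its center.
#   >>> boxes = torch.tensor([[10., 10., 20., 20.]])
#   >>> bbox_rescale(boxes, scale_factor=2.0)
#   tensor([[ 5.,  5., 25., 25.]])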


def bbox_cxcywh_to_xyxy(bbox):
    """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).

    Args:
        bbox (Tensor): Shape (n, 4) for bboxes.

    Returns:
        Tensor: Converted bboxes.
    """
    cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
    bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
    return torch.cat(bbox_new, dim=-1)


def bbox_xyxy_to_cxcywh(bbox):
    """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).

    Args:
        bbox (Tensor): Shape (n, 4) for bboxes.

    Returns:
        Tensor: Converted bboxes.
    """
    x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
    bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
    return torch.cat(bbox_new, dim=-1)
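

# Usage sketch (illustrative values): the two conversions are inverses.
#   >>> xyxy = torch.tensor([[10., 10., 30., 50.]])
#   >>> bbox_xyxy_to_cxcywh(xyxy)
#   tensor([[20., 30., 20., 40.]])
#   >>> bbox_cxcywh_to_xyxy(bbox_xyxy_to_cxcywh(xyxy))
#   tensor([[10., 10., 30., 50.]])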
