You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

boxes.py 10 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import copy
  3. import numpy as np
  4. from enum import Enum, unique
  5. from typing import Iterator, List, Tuple, Union
  6. import torch
  7. from detectron2.layers import cat
  8. _RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
  9. @unique
  10. class BoxMode(Enum):
  11. """
  12. Enum of different ways to represent a box.
  13. Attributes:
  14. XYXY_ABS: (x0, y0, x1, y1) in absolute floating points coordinates.
  15. The coordinates in range [0, width or height].
  16. XYWH_ABS: (x0, y0, w, h) in absolute floating points coordinates.
  17. XYXY_REL: (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
  18. XYWH_REL: (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
  19. """
  20. XYXY_ABS = 0
  21. XYWH_ABS = 1
  22. XYXY_REL = 2
  23. XYWH_REL = 3
  24. @staticmethod
  25. def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
  26. """
  27. Args:
  28. box: can be a 4-tuple, 4-list or a Nx4 array/tensor.
  29. from_mode, to_mode (BoxMode)
  30. Returns:
  31. The converted box of the same type.
  32. """
  33. if from_mode == to_mode:
  34. return box
  35. original_type = type(box)
  36. single_box = isinstance(box, (list, tuple))
  37. if single_box:
  38. arr = np.array(box)
  39. assert arr.shape == (
  40. 4,
  41. ), "BoxMode.convert takes either a 4-tuple/list or a Nx4 array/tensor"
  42. else:
  43. arr = copy.deepcopy(box) # avoid modifying the input box
  44. assert to_mode.value < 2 and from_mode.value < 2, "Relative mode not yet supported!"
  45. original_shape = arr.shape
  46. arr = arr.reshape(-1, 4)
  47. if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
  48. arr[:, 2] += arr[:, 0]
  49. arr[:, 3] += arr[:, 1]
  50. elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
  51. arr[:, 2] -= arr[:, 0]
  52. arr[:, 3] -= arr[:, 1]
  53. else:
  54. raise RuntimeError("Cannot be here!")
  55. if single_box:
  56. return original_type(arr.flatten())
  57. return arr.reshape(*original_shape)
  58. class Boxes:
  59. """
  60. This structure stores a list of boxes as a Nx4 torch.Tensor.
  61. It supports some common methods about boxes
  62. (`area`, `clip`, `nonempty`, etc),
  63. and also behaves like a Tensor
  64. (support indexing, `to(device)`, `.device`, and iteration over all boxes)
  65. Attributes:
  66. tensor: float matrix of Nx4.
  67. """
  68. BoxSizeType = Union[List[int], Tuple[int, int]]
  69. def __init__(self, tensor: torch.Tensor):
  70. """
  71. Args:
  72. tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
  73. """
  74. device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
  75. tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
  76. if tensor.numel() == 0:
  77. tensor = torch.zeros(0, 4, dtype=torch.float32, device=device)
  78. assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
  79. self.tensor = tensor
  80. def clone(self) -> "Boxes":
  81. """
  82. Clone the Boxes.
  83. Returns:
  84. Boxes
  85. """
  86. return Boxes(self.tensor.clone())
  87. def to(self, device: str) -> "Boxes":
  88. return Boxes(self.tensor.to(device))
  89. def area(self) -> torch.Tensor:
  90. """
  91. Computes the area of all the boxes.
  92. Returns:
  93. torch.Tensor: a vector with areas of each box.
  94. """
  95. box = self.tensor
  96. area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
  97. return area
  98. def clip(self, box_size: BoxSizeType) -> None:
  99. """
  100. Clip (in place) the boxes by limiting x coordinates to the range [0, width]
  101. and y coordinates to the range [0, height].
  102. Args:
  103. box_size (height, width): The clipping box's size.
  104. """
  105. assert torch.isfinite(self.tensor).all()
  106. h, w = box_size
  107. self.tensor[:, 0].clamp_(min=0, max=w)
  108. self.tensor[:, 1].clamp_(min=0, max=h)
  109. self.tensor[:, 2].clamp_(min=0, max=w)
  110. self.tensor[:, 3].clamp_(min=0, max=h)
  111. def nonempty(self, threshold: int = 0) -> torch.Tensor:
  112. """
  113. Find boxes that are non-empty.
  114. A box is considered empty, if either of its side is no larger than threshold.
  115. Returns:
  116. Tensor:
  117. a binary vector which represents whether each box is empty
  118. (False) or non-empty (True).
  119. """
  120. box = self.tensor
  121. widths = box[:, 2] - box[:, 0]
  122. heights = box[:, 3] - box[:, 1]
  123. keep = (widths > threshold) & (heights > threshold)
  124. return keep
  125. def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Boxes":
  126. """
  127. Returns:
  128. Boxes: Create a new :class:`Boxes` by indexing.
  129. The following usage are allowed:
  130. 1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
  131. 2. `new_boxes = boxes[2:10]`: return a slice of boxes.
  132. 3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
  133. with `length = len(boxes)`. Nonzero elements in the vector will be selected.
  134. Note that the returned Boxes might share storage with this Boxes,
  135. subject to Pytorch's indexing semantics.
  136. """
  137. if isinstance(item, int):
  138. return Boxes(self.tensor[item].view(1, -1))
  139. b = self.tensor[item]
  140. assert b.dim() == 2, "Indexing on Boxes with {} failed to return a matrix!".format(item)
  141. return Boxes(b)
  142. def __len__(self) -> int:
  143. return self.tensor.shape[0]
  144. def __repr__(self) -> str:
  145. return "Boxes(" + str(self.tensor) + ")"
  146. def inside_box(self, box_size: BoxSizeType, boundary_threshold: int = 0) -> torch.Tensor:
  147. """
  148. Args:
  149. box_size (height, width): Size of the reference box.
  150. boundary_threshold (int): Boxes that extend beyond the reference box
  151. boundary by more than boundary_threshold are considered "outside".
  152. Returns:
  153. a binary vector, indicating whether each box is inside the reference box.
  154. """
  155. height, width = box_size
  156. inds_inside = (
  157. (self.tensor[..., 0] >= -boundary_threshold)
  158. & (self.tensor[..., 1] >= -boundary_threshold)
  159. & (self.tensor[..., 2] < width + boundary_threshold)
  160. & (self.tensor[..., 3] < height + boundary_threshold)
  161. )
  162. return inds_inside
  163. def get_centers(self) -> torch.Tensor:
  164. """
  165. Returns:
  166. The box centers in a Nx2 array of (x, y).
  167. """
  168. return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2
  169. def scale(self, scale_x: float, scale_y: float) -> None:
  170. """
  171. Scale the box with horizontal and vertical scaling factors
  172. """
  173. self.tensor[:, 0::2] *= scale_x
  174. self.tensor[:, 1::2] *= scale_y
  175. @staticmethod
  176. def cat(boxes_list: List["Boxes"]) -> "Boxes":
  177. """
  178. Concatenates a list of Boxes into a single Boxes
  179. Arguments:
  180. boxes_list (list[Boxes])
  181. Returns:
  182. Boxes: the concatenated Boxes
  183. """
  184. assert isinstance(boxes_list, (list, tuple))
  185. assert len(boxes_list) > 0
  186. assert all(isinstance(box, Boxes) for box in boxes_list)
  187. cat_boxes = type(boxes_list[0])(cat([b.tensor for b in boxes_list], dim=0))
  188. return cat_boxes
  189. @property
  190. def device(self) -> str:
  191. return self.tensor.device
  192. def __iter__(self) -> Iterator[torch.Tensor]:
  193. """
  194. Yield a box as a Tensor of shape (4,) at a time.
  195. """
  196. yield from self.tensor
  197. # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
  198. # with slight modifications
  199. def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
  200. """
  201. Given two lists of boxes of size N and M,
  202. compute the IoU (intersection over union)
  203. between __all__ N x M pairs of boxes.
  204. The box order must be (xmin, ymin, xmax, ymax).
  205. Args:
  206. boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
  207. Returns:
  208. Tensor: IoU, sized [N,M].
  209. """
  210. area1 = boxes1.area()
  211. area2 = boxes2.area()
  212. boxes1, boxes2 = boxes1.tensor, boxes2.tensor
  213. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  214. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  215. wh = (rb - lt).clamp(min=0) # [N,M,2]
  216. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  217. # handle empty boxes
  218. iou = torch.where(
  219. inter > 0,
  220. inter / (area1[:, None] + area2 - inter),
  221. torch.zeros(1, dtype=inter.dtype, device=inter.device),
  222. )
  223. return iou
  224. def matched_boxlist_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
  225. """
  226. Compute pairwise intersection over union (IOU) of two sets of matched
  227. boxes. The box order must be (xmin, ymin, xmax, ymax).
  228. Similar to boxlist_iou, but computes only diagonal elements of the matrix
  229. Arguments:
  230. boxes1: (Boxes) bounding boxes, sized [N,4].
  231. boxes2: (Boxes) bounding boxes, sized [N,4].
  232. Returns:
  233. (tensor) iou, sized [N].
  234. """
  235. assert len(boxes1) == len(boxes2), (
  236. "boxlists should have the same"
  237. "number of entries, got {}, {}".format(len(boxes1), len(boxes2))
  238. )
  239. area1 = boxes1.area() # [N]
  240. area2 = boxes2.area() # [N]
  241. box1, box2 = boxes1.tensor, boxes2.tensor
  242. lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2]
  243. rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2]
  244. wh = (rb - lt).clamp(min=0) # [N,2]
  245. inter = wh[:, 0] * wh[:, 1] # [N]
  246. iou = inter / (area1 + area2 - inter) # [N]
  247. return iou

No Description