You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

structures.py 40 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. import cv2
  4. import mmcv
  5. import numpy as np
  6. import pycocotools.mask as maskUtils
  7. import torch
  8. from mmcv.ops.roi_align import roi_align
  9. class BaseInstanceMasks(metaclass=ABCMeta):
  10. """Base class for instance masks."""
  11. @abstractmethod
  12. def rescale(self, scale, interpolation='nearest'):
  13. """Rescale masks as large as possible while keeping the aspect ratio.
  14. For details can refer to `mmcv.imrescale`.
  15. Args:
  16. scale (tuple[int]): The maximum size (h, w) of rescaled mask.
  17. interpolation (str): Same as :func:`mmcv.imrescale`.
  18. Returns:
  19. BaseInstanceMasks: The rescaled masks.
  20. """
  21. @abstractmethod
  22. def resize(self, out_shape, interpolation='nearest'):
  23. """Resize masks to the given out_shape.
  24. Args:
  25. out_shape: Target (h, w) of resized mask.
  26. interpolation (str): See :func:`mmcv.imresize`.
  27. Returns:
  28. BaseInstanceMasks: The resized masks.
  29. """
  30. @abstractmethod
  31. def flip(self, flip_direction='horizontal'):
  32. """Flip masks alone the given direction.
  33. Args:
  34. flip_direction (str): Either 'horizontal' or 'vertical'.
  35. Returns:
  36. BaseInstanceMasks: The flipped masks.
  37. """
  38. @abstractmethod
  39. def pad(self, out_shape, pad_val):
  40. """Pad masks to the given size of (h, w).
  41. Args:
  42. out_shape (tuple[int]): Target (h, w) of padded mask.
  43. pad_val (int): The padded value.
  44. Returns:
  45. BaseInstanceMasks: The padded masks.
  46. """
  47. @abstractmethod
  48. def crop(self, bbox):
  49. """Crop each mask by the given bbox.
  50. Args:
  51. bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ).
  52. Return:
  53. BaseInstanceMasks: The cropped masks.
  54. """
  55. @abstractmethod
  56. def crop_and_resize(self,
  57. bboxes,
  58. out_shape,
  59. inds,
  60. device,
  61. interpolation='bilinear',
  62. binarize=True):
  63. """Crop and resize masks by the given bboxes.
  64. This function is mainly used in mask targets computation.
  65. It firstly align mask to bboxes by assigned_inds, then crop mask by the
  66. assigned bbox and resize to the size of (mask_h, mask_w)
  67. Args:
  68. bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4)
  69. out_shape (tuple[int]): Target (h, w) of resized mask
  70. inds (ndarray): Indexes to assign masks to each bbox,
  71. shape (N,) and values should be between [0, num_masks - 1].
  72. device (str): Device of bboxes
  73. interpolation (str): See `mmcv.imresize`
  74. binarize (bool): if True fractional values are rounded to 0 or 1
  75. after the resize operation. if False and unsupported an error
  76. will be raised. Defaults to True.
  77. Return:
  78. BaseInstanceMasks: the cropped and resized masks.
  79. """
  80. @abstractmethod
  81. def expand(self, expanded_h, expanded_w, top, left):
  82. """see :class:`Expand`."""
  83. @property
  84. @abstractmethod
  85. def areas(self):
  86. """ndarray: areas of each instance."""
  87. @abstractmethod
  88. def to_ndarray(self):
  89. """Convert masks to the format of ndarray.
  90. Return:
  91. ndarray: Converted masks in the format of ndarray.
  92. """
  93. @abstractmethod
  94. def to_tensor(self, dtype, device):
  95. """Convert masks to the format of Tensor.
  96. Args:
  97. dtype (str): Dtype of converted mask.
  98. device (torch.device): Device of converted masks.
  99. Returns:
  100. Tensor: Converted masks in the format of Tensor.
  101. """
  102. @abstractmethod
  103. def translate(self,
  104. out_shape,
  105. offset,
  106. direction='horizontal',
  107. fill_val=0,
  108. interpolation='bilinear'):
  109. """Translate the masks.
  110. Args:
  111. out_shape (tuple[int]): Shape for output mask, format (h, w).
  112. offset (int | float): The offset for translate.
  113. direction (str): The translate direction, either "horizontal"
  114. or "vertical".
  115. fill_val (int | float): Border value. Default 0.
  116. interpolation (str): Same as :func:`mmcv.imtranslate`.
  117. Returns:
  118. Translated masks.
  119. """
  120. def shear(self,
  121. out_shape,
  122. magnitude,
  123. direction='horizontal',
  124. border_value=0,
  125. interpolation='bilinear'):
  126. """Shear the masks.
  127. Args:
  128. out_shape (tuple[int]): Shape for output mask, format (h, w).
  129. magnitude (int | float): The magnitude used for shear.
  130. direction (str): The shear direction, either "horizontal"
  131. or "vertical".
  132. border_value (int | tuple[int]): Value used in case of a
  133. constant border. Default 0.
  134. interpolation (str): Same as in :func:`mmcv.imshear`.
  135. Returns:
  136. ndarray: Sheared masks.
  137. """
  138. @abstractmethod
  139. def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
  140. """Rotate the masks.
  141. Args:
  142. out_shape (tuple[int]): Shape for output mask, format (h, w).
  143. angle (int | float): Rotation angle in degrees. Positive values
  144. mean counter-clockwise rotation.
  145. center (tuple[float], optional): Center point (w, h) of the
  146. rotation in source image. If not specified, the center of
  147. the image will be used.
  148. scale (int | float): Isotropic scale factor.
  149. fill_val (int | float): Border value. Default 0 for masks.
  150. Returns:
  151. Rotated masks.
  152. """
  153. class BitmapMasks(BaseInstanceMasks):
  154. """This class represents masks in the form of bitmaps.
  155. Args:
  156. masks (ndarray): ndarray of masks in shape (N, H, W), where N is
  157. the number of objects.
  158. height (int): height of masks
  159. width (int): width of masks
  160. Example:
  161. >>> from mmdet.core.mask.structures import * # NOQA
  162. >>> num_masks, H, W = 3, 32, 32
  163. >>> rng = np.random.RandomState(0)
  164. >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.int)
  165. >>> self = BitmapMasks(masks, height=H, width=W)
  166. >>> # demo crop_and_resize
  167. >>> num_boxes = 5
  168. >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
  169. >>> out_shape = (14, 14)
  170. >>> inds = torch.randint(0, len(self), size=(num_boxes,))
  171. >>> device = 'cpu'
  172. >>> interpolation = 'bilinear'
  173. >>> new = self.crop_and_resize(
  174. ... bboxes, out_shape, inds, device, interpolation)
  175. >>> assert len(new) == num_boxes
  176. >>> assert new.height, new.width == out_shape
  177. """
  178. def __init__(self, masks, height, width):
  179. self.height = height
  180. self.width = width
  181. if len(masks) == 0:
  182. self.masks = np.empty((0, self.height, self.width), dtype=np.uint8)
  183. else:
  184. assert isinstance(masks, (list, np.ndarray))
  185. if isinstance(masks, list):
  186. assert isinstance(masks[0], np.ndarray)
  187. assert masks[0].ndim == 2 # (H, W)
  188. else:
  189. assert masks.ndim == 3 # (N, H, W)
  190. self.masks = np.stack(masks).reshape(-1, height, width)
  191. assert self.masks.shape[1] == self.height
  192. assert self.masks.shape[2] == self.width
  193. def __getitem__(self, index):
  194. """Index the BitmapMask.
  195. Args:
  196. index (int | ndarray): Indices in the format of integer or ndarray.
  197. Returns:
  198. :obj:`BitmapMasks`: Indexed bitmap masks.
  199. """
  200. masks = self.masks[index].reshape(-1, self.height, self.width)
  201. return BitmapMasks(masks, self.height, self.width)
  202. def __iter__(self):
  203. return iter(self.masks)
  204. def __repr__(self):
  205. s = self.__class__.__name__ + '('
  206. s += f'num_masks={len(self.masks)}, '
  207. s += f'height={self.height}, '
  208. s += f'width={self.width})'
  209. return s
  210. def __len__(self):
  211. """Number of masks."""
  212. return len(self.masks)
  213. def rescale(self, scale, interpolation='nearest'):
  214. """See :func:`BaseInstanceMasks.rescale`."""
  215. if len(self.masks) == 0:
  216. new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
  217. rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8)
  218. else:
  219. rescaled_masks = np.stack([
  220. mmcv.imrescale(mask, scale, interpolation=interpolation)
  221. for mask in self.masks
  222. ])
  223. height, width = rescaled_masks.shape[1:]
  224. return BitmapMasks(rescaled_masks, height, width)
  225. def resize(self, out_shape, interpolation='nearest'):
  226. """See :func:`BaseInstanceMasks.resize`."""
  227. if len(self.masks) == 0:
  228. resized_masks = np.empty((0, *out_shape), dtype=np.uint8)
  229. else:
  230. resized_masks = np.stack([
  231. mmcv.imresize(
  232. mask, out_shape[::-1], interpolation=interpolation)
  233. for mask in self.masks
  234. ])
  235. return BitmapMasks(resized_masks, *out_shape)
  236. def flip(self, flip_direction='horizontal'):
  237. """See :func:`BaseInstanceMasks.flip`."""
  238. assert flip_direction in ('horizontal', 'vertical', 'diagonal')
  239. if len(self.masks) == 0:
  240. flipped_masks = self.masks
  241. else:
  242. flipped_masks = np.stack([
  243. mmcv.imflip(mask, direction=flip_direction)
  244. for mask in self.masks
  245. ])
  246. return BitmapMasks(flipped_masks, self.height, self.width)
  247. def pad(self, out_shape, pad_val=0):
  248. """See :func:`BaseInstanceMasks.pad`."""
  249. if len(self.masks) == 0:
  250. padded_masks = np.empty((0, *out_shape), dtype=np.uint8)
  251. else:
  252. padded_masks = np.stack([
  253. mmcv.impad(mask, shape=out_shape, pad_val=pad_val)
  254. for mask in self.masks
  255. ])
  256. return BitmapMasks(padded_masks, *out_shape)
  257. def crop(self, bbox):
  258. """See :func:`BaseInstanceMasks.crop`."""
  259. assert isinstance(bbox, np.ndarray)
  260. assert bbox.ndim == 1
  261. # clip the boundary
  262. bbox = bbox.copy()
  263. bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
  264. bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
  265. x1, y1, x2, y2 = bbox
  266. w = np.maximum(x2 - x1, 1)
  267. h = np.maximum(y2 - y1, 1)
  268. if len(self.masks) == 0:
  269. cropped_masks = np.empty((0, h, w), dtype=np.uint8)
  270. else:
  271. cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w]
  272. return BitmapMasks(cropped_masks, h, w)
  273. def crop_and_resize(self,
  274. bboxes,
  275. out_shape,
  276. inds,
  277. device='cpu',
  278. interpolation='bilinear',
  279. binarize=True):
  280. """See :func:`BaseInstanceMasks.crop_and_resize`."""
  281. if len(self.masks) == 0:
  282. empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
  283. return BitmapMasks(empty_masks, *out_shape)
  284. # convert bboxes to tensor
  285. if isinstance(bboxes, np.ndarray):
  286. bboxes = torch.from_numpy(bboxes).to(device=device)
  287. if isinstance(inds, np.ndarray):
  288. inds = torch.from_numpy(inds).to(device=device)
  289. num_bbox = bboxes.shape[0]
  290. fake_inds = torch.arange(
  291. num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
  292. rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5
  293. rois = rois.to(device=device)
  294. if num_bbox > 0:
  295. gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
  296. 0, inds).to(dtype=rois.dtype)
  297. targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
  298. 1.0, 0, 'avg', True).squeeze(1)
  299. if binarize:
  300. resized_masks = (targets >= 0.5).cpu().numpy()
  301. else:
  302. resized_masks = targets.cpu().numpy()
  303. else:
  304. resized_masks = []
  305. return BitmapMasks(resized_masks, *out_shape)
  306. def expand(self, expanded_h, expanded_w, top, left):
  307. """See :func:`BaseInstanceMasks.expand`."""
  308. if len(self.masks) == 0:
  309. expanded_mask = np.empty((0, expanded_h, expanded_w),
  310. dtype=np.uint8)
  311. else:
  312. expanded_mask = np.zeros((len(self), expanded_h, expanded_w),
  313. dtype=np.uint8)
  314. expanded_mask[:, top:top + self.height,
  315. left:left + self.width] = self.masks
  316. return BitmapMasks(expanded_mask, expanded_h, expanded_w)
  317. def translate(self,
  318. out_shape,
  319. offset,
  320. direction='horizontal',
  321. fill_val=0,
  322. interpolation='bilinear'):
  323. """Translate the BitmapMasks.
  324. Args:
  325. out_shape (tuple[int]): Shape for output mask, format (h, w).
  326. offset (int | float): The offset for translate.
  327. direction (str): The translate direction, either "horizontal"
  328. or "vertical".
  329. fill_val (int | float): Border value. Default 0 for masks.
  330. interpolation (str): Same as :func:`mmcv.imtranslate`.
  331. Returns:
  332. BitmapMasks: Translated BitmapMasks.
  333. Example:
  334. >>> from mmdet.core.mask.structures import BitmapMasks
  335. >>> self = BitmapMasks.random(dtype=np.uint8)
  336. >>> out_shape = (32, 32)
  337. >>> offset = 4
  338. >>> direction = 'horizontal'
  339. >>> fill_val = 0
  340. >>> interpolation = 'bilinear'
  341. >>> # Note, There seem to be issues when:
  342. >>> # * out_shape is different than self's shape
  343. >>> # * the mask dtype is not supported by cv2.AffineWarp
  344. >>> new = self.translate(out_shape, offset, direction, fill_val,
  345. >>> interpolation)
  346. >>> assert len(new) == len(self)
  347. >>> assert new.height, new.width == out_shape
  348. """
  349. if len(self.masks) == 0:
  350. translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
  351. else:
  352. translated_masks = mmcv.imtranslate(
  353. self.masks.transpose((1, 2, 0)),
  354. offset,
  355. direction,
  356. border_value=fill_val,
  357. interpolation=interpolation)
  358. if translated_masks.ndim == 2:
  359. translated_masks = translated_masks[:, :, None]
  360. translated_masks = translated_masks.transpose(
  361. (2, 0, 1)).astype(self.masks.dtype)
  362. return BitmapMasks(translated_masks, *out_shape)
  363. def shear(self,
  364. out_shape,
  365. magnitude,
  366. direction='horizontal',
  367. border_value=0,
  368. interpolation='bilinear'):
  369. """Shear the BitmapMasks.
  370. Args:
  371. out_shape (tuple[int]): Shape for output mask, format (h, w).
  372. magnitude (int | float): The magnitude used for shear.
  373. direction (str): The shear direction, either "horizontal"
  374. or "vertical".
  375. border_value (int | tuple[int]): Value used in case of a
  376. constant border.
  377. interpolation (str): Same as in :func:`mmcv.imshear`.
  378. Returns:
  379. BitmapMasks: The sheared masks.
  380. """
  381. if len(self.masks) == 0:
  382. sheared_masks = np.empty((0, *out_shape), dtype=np.uint8)
  383. else:
  384. sheared_masks = mmcv.imshear(
  385. self.masks.transpose((1, 2, 0)),
  386. magnitude,
  387. direction,
  388. border_value=border_value,
  389. interpolation=interpolation)
  390. if sheared_masks.ndim == 2:
  391. sheared_masks = sheared_masks[:, :, None]
  392. sheared_masks = sheared_masks.transpose(
  393. (2, 0, 1)).astype(self.masks.dtype)
  394. return BitmapMasks(sheared_masks, *out_shape)
  395. def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
  396. """Rotate the BitmapMasks.
  397. Args:
  398. out_shape (tuple[int]): Shape for output mask, format (h, w).
  399. angle (int | float): Rotation angle in degrees. Positive values
  400. mean counter-clockwise rotation.
  401. center (tuple[float], optional): Center point (w, h) of the
  402. rotation in source image. If not specified, the center of
  403. the image will be used.
  404. scale (int | float): Isotropic scale factor.
  405. fill_val (int | float): Border value. Default 0 for masks.
  406. Returns:
  407. BitmapMasks: Rotated BitmapMasks.
  408. """
  409. if len(self.masks) == 0:
  410. rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype)
  411. else:
  412. rotated_masks = mmcv.imrotate(
  413. self.masks.transpose((1, 2, 0)),
  414. angle,
  415. center=center,
  416. scale=scale,
  417. border_value=fill_val)
  418. if rotated_masks.ndim == 2:
  419. # case when only one mask, (h, w)
  420. rotated_masks = rotated_masks[:, :, None] # (h, w, 1)
  421. rotated_masks = rotated_masks.transpose(
  422. (2, 0, 1)).astype(self.masks.dtype)
  423. return BitmapMasks(rotated_masks, *out_shape)
  424. @property
  425. def areas(self):
  426. """See :py:attr:`BaseInstanceMasks.areas`."""
  427. return self.masks.sum((1, 2))
  428. def to_ndarray(self):
  429. """See :func:`BaseInstanceMasks.to_ndarray`."""
  430. return self.masks
  431. def to_tensor(self, dtype, device):
  432. """See :func:`BaseInstanceMasks.to_tensor`."""
  433. return torch.tensor(self.masks, dtype=dtype, device=device)
  434. @classmethod
  435. def random(cls,
  436. num_masks=3,
  437. height=32,
  438. width=32,
  439. dtype=np.uint8,
  440. rng=None):
  441. """Generate random bitmap masks for demo / testing purposes.
  442. Example:
  443. >>> from mmdet.core.mask.structures import BitmapMasks
  444. >>> self = BitmapMasks.random()
  445. >>> print('self = {}'.format(self))
  446. self = BitmapMasks(num_masks=3, height=32, width=32)
  447. """
  448. from mmdet.utils.util_random import ensure_rng
  449. rng = ensure_rng(rng)
  450. masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype)
  451. self = cls(masks, height=height, width=width)
  452. return self
  453. def get_bboxes(self):
  454. num_masks = len(self)
  455. boxes = np.zeros((num_masks, 4), dtype=np.float32)
  456. x_any = self.masks.any(axis=1)
  457. y_any = self.masks.any(axis=2)
  458. for idx in range(num_masks):
  459. x = np.where(x_any[idx, :])[0]
  460. y = np.where(y_any[idx, :])[0]
  461. if len(x) > 0 and len(y) > 0:
  462. # use +1 for x_max and y_max so that the right and bottom
  463. # boundary of instance masks are fully included by the box
  464. boxes[idx, :] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1],
  465. dtype=np.float32)
  466. return boxes
  467. class PolygonMasks(BaseInstanceMasks):
  468. """This class represents masks in the form of polygons.
  469. Polygons is a list of three levels. The first level of the list
  470. corresponds to objects, the second level to the polys that compose the
  471. object, the third level to the poly coordinates
  472. Args:
  473. masks (list[list[ndarray]]): The first level of the list
  474. corresponds to objects, the second level to the polys that
  475. compose the object, the third level to the poly coordinates
  476. height (int): height of masks
  477. width (int): width of masks
  478. Example:
  479. >>> from mmdet.core.mask.structures import * # NOQA
  480. >>> masks = [
  481. >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ]
  482. >>> ]
  483. >>> height, width = 16, 16
  484. >>> self = PolygonMasks(masks, height, width)
  485. >>> # demo translate
  486. >>> new = self.translate((16, 16), 4., direction='horizontal')
  487. >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2])
  488. >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4)
  489. >>> # demo crop_and_resize
  490. >>> num_boxes = 3
  491. >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
  492. >>> out_shape = (16, 16)
  493. >>> inds = torch.randint(0, len(self), size=(num_boxes,))
  494. >>> device = 'cpu'
  495. >>> interpolation = 'bilinear'
  496. >>> new = self.crop_and_resize(
  497. ... bboxes, out_shape, inds, device, interpolation)
  498. >>> assert len(new) == num_boxes
  499. >>> assert new.height, new.width == out_shape
  500. """
  501. def __init__(self, masks, height, width):
  502. assert isinstance(masks, list)
  503. if len(masks) > 0:
  504. assert isinstance(masks[0], list)
  505. assert isinstance(masks[0][0], np.ndarray)
  506. self.height = height
  507. self.width = width
  508. self.masks = masks
  509. def __getitem__(self, index):
  510. """Index the polygon masks.
  511. Args:
  512. index (ndarray | List): The indices.
  513. Returns:
  514. :obj:`PolygonMasks`: The indexed polygon masks.
  515. """
  516. if isinstance(index, np.ndarray):
  517. index = index.tolist()
  518. if isinstance(index, list):
  519. masks = [self.masks[i] for i in index]
  520. else:
  521. try:
  522. masks = self.masks[index]
  523. except Exception:
  524. raise ValueError(
  525. f'Unsupported input of type {type(index)} for indexing!')
  526. if len(masks) and isinstance(masks[0], np.ndarray):
  527. masks = [masks] # ensure a list of three levels
  528. return PolygonMasks(masks, self.height, self.width)
  529. def __iter__(self):
  530. return iter(self.masks)
  531. def __repr__(self):
  532. s = self.__class__.__name__ + '('
  533. s += f'num_masks={len(self.masks)}, '
  534. s += f'height={self.height}, '
  535. s += f'width={self.width})'
  536. return s
  537. def __len__(self):
  538. """Number of masks."""
  539. return len(self.masks)
  540. def rescale(self, scale, interpolation=None):
  541. """see :func:`BaseInstanceMasks.rescale`"""
  542. new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
  543. if len(self.masks) == 0:
  544. rescaled_masks = PolygonMasks([], new_h, new_w)
  545. else:
  546. rescaled_masks = self.resize((new_h, new_w))
  547. return rescaled_masks
  548. def resize(self, out_shape, interpolation=None):
  549. """see :func:`BaseInstanceMasks.resize`"""
  550. if len(self.masks) == 0:
  551. resized_masks = PolygonMasks([], *out_shape)
  552. else:
  553. h_scale = out_shape[0] / self.height
  554. w_scale = out_shape[1] / self.width
  555. resized_masks = []
  556. for poly_per_obj in self.masks:
  557. resized_poly = []
  558. for p in poly_per_obj:
  559. p = p.copy()
  560. p[0::2] = p[0::2] * w_scale
  561. p[1::2] = p[1::2] * h_scale
  562. resized_poly.append(p)
  563. resized_masks.append(resized_poly)
  564. resized_masks = PolygonMasks(resized_masks, *out_shape)
  565. return resized_masks
  566. def flip(self, flip_direction='horizontal'):
  567. """see :func:`BaseInstanceMasks.flip`"""
  568. assert flip_direction in ('horizontal', 'vertical', 'diagonal')
  569. if len(self.masks) == 0:
  570. flipped_masks = PolygonMasks([], self.height, self.width)
  571. else:
  572. flipped_masks = []
  573. for poly_per_obj in self.masks:
  574. flipped_poly_per_obj = []
  575. for p in poly_per_obj:
  576. p = p.copy()
  577. if flip_direction == 'horizontal':
  578. p[0::2] = self.width - p[0::2]
  579. elif flip_direction == 'vertical':
  580. p[1::2] = self.height - p[1::2]
  581. else:
  582. p[0::2] = self.width - p[0::2]
  583. p[1::2] = self.height - p[1::2]
  584. flipped_poly_per_obj.append(p)
  585. flipped_masks.append(flipped_poly_per_obj)
  586. flipped_masks = PolygonMasks(flipped_masks, self.height,
  587. self.width)
  588. return flipped_masks
  589. def crop(self, bbox):
  590. """see :func:`BaseInstanceMasks.crop`"""
  591. assert isinstance(bbox, np.ndarray)
  592. assert bbox.ndim == 1
  593. # clip the boundary
  594. bbox = bbox.copy()
  595. bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
  596. bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
  597. x1, y1, x2, y2 = bbox
  598. w = np.maximum(x2 - x1, 1)
  599. h = np.maximum(y2 - y1, 1)
  600. if len(self.masks) == 0:
  601. cropped_masks = PolygonMasks([], h, w)
  602. else:
  603. cropped_masks = []
  604. for poly_per_obj in self.masks:
  605. cropped_poly_per_obj = []
  606. for p in poly_per_obj:
  607. # pycocotools will clip the boundary
  608. p = p.copy()
  609. p[0::2] = p[0::2] - bbox[0]
  610. p[1::2] = p[1::2] - bbox[1]
  611. cropped_poly_per_obj.append(p)
  612. cropped_masks.append(cropped_poly_per_obj)
  613. cropped_masks = PolygonMasks(cropped_masks, h, w)
  614. return cropped_masks
  615. def pad(self, out_shape, pad_val=0):
  616. """padding has no effect on polygons`"""
  617. return PolygonMasks(self.masks, *out_shape)
  618. def expand(self, *args, **kwargs):
  619. """TODO: Add expand for polygon"""
  620. raise NotImplementedError
  621. def crop_and_resize(self,
  622. bboxes,
  623. out_shape,
  624. inds,
  625. device='cpu',
  626. interpolation='bilinear',
  627. binarize=True):
  628. """see :func:`BaseInstanceMasks.crop_and_resize`"""
  629. out_h, out_w = out_shape
  630. if len(self.masks) == 0:
  631. return PolygonMasks([], out_h, out_w)
  632. if not binarize:
  633. raise ValueError('Polygons are always binary, '
  634. 'setting binarize=False is unsupported')
  635. resized_masks = []
  636. for i in range(len(bboxes)):
  637. mask = self.masks[inds[i]]
  638. bbox = bboxes[i, :]
  639. x1, y1, x2, y2 = bbox
  640. w = np.maximum(x2 - x1, 1)
  641. h = np.maximum(y2 - y1, 1)
  642. h_scale = out_h / max(h, 0.1) # avoid too large scale
  643. w_scale = out_w / max(w, 0.1)
  644. resized_mask = []
  645. for p in mask:
  646. p = p.copy()
  647. # crop
  648. # pycocotools will clip the boundary
  649. p[0::2] = p[0::2] - bbox[0]
  650. p[1::2] = p[1::2] - bbox[1]
  651. # resize
  652. p[0::2] = p[0::2] * w_scale
  653. p[1::2] = p[1::2] * h_scale
  654. resized_mask.append(p)
  655. resized_masks.append(resized_mask)
  656. return PolygonMasks(resized_masks, *out_shape)
  657. def translate(self,
  658. out_shape,
  659. offset,
  660. direction='horizontal',
  661. fill_val=None,
  662. interpolation=None):
  663. """Translate the PolygonMasks.
  664. Example:
  665. >>> self = PolygonMasks.random(dtype=np.int)
  666. >>> out_shape = (self.height, self.width)
  667. >>> new = self.translate(out_shape, 4., direction='horizontal')
  668. >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2])
  669. >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4) # noqa: E501
  670. """
  671. assert fill_val is None or fill_val == 0, 'Here fill_val is not '\
  672. f'used, and defaultly should be None or 0. got {fill_val}.'
  673. if len(self.masks) == 0:
  674. translated_masks = PolygonMasks([], *out_shape)
  675. else:
  676. translated_masks = []
  677. for poly_per_obj in self.masks:
  678. translated_poly_per_obj = []
  679. for p in poly_per_obj:
  680. p = p.copy()
  681. if direction == 'horizontal':
  682. p[0::2] = np.clip(p[0::2] + offset, 0, out_shape[1])
  683. elif direction == 'vertical':
  684. p[1::2] = np.clip(p[1::2] + offset, 0, out_shape[0])
  685. translated_poly_per_obj.append(p)
  686. translated_masks.append(translated_poly_per_obj)
  687. translated_masks = PolygonMasks(translated_masks, *out_shape)
  688. return translated_masks
  689. def shear(self,
  690. out_shape,
  691. magnitude,
  692. direction='horizontal',
  693. border_value=0,
  694. interpolation='bilinear'):
  695. """See :func:`BaseInstanceMasks.shear`."""
  696. if len(self.masks) == 0:
  697. sheared_masks = PolygonMasks([], *out_shape)
  698. else:
  699. sheared_masks = []
  700. if direction == 'horizontal':
  701. shear_matrix = np.stack([[1, magnitude],
  702. [0, 1]]).astype(np.float32)
  703. elif direction == 'vertical':
  704. shear_matrix = np.stack([[1, 0], [magnitude,
  705. 1]]).astype(np.float32)
  706. for poly_per_obj in self.masks:
  707. sheared_poly = []
  708. for p in poly_per_obj:
  709. p = np.stack([p[0::2], p[1::2]], axis=0) # [2, n]
  710. new_coords = np.matmul(shear_matrix, p) # [2, n]
  711. new_coords[0, :] = np.clip(new_coords[0, :], 0,
  712. out_shape[1])
  713. new_coords[1, :] = np.clip(new_coords[1, :], 0,
  714. out_shape[0])
  715. sheared_poly.append(
  716. new_coords.transpose((1, 0)).reshape(-1))
  717. sheared_masks.append(sheared_poly)
  718. sheared_masks = PolygonMasks(sheared_masks, *out_shape)
  719. return sheared_masks
  def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
      """See :func:`BaseInstanceMasks.rotate`.

      Every polygon vertex is mapped through the affine matrix returned by
      ``cv2.getRotationMatrix2D`` and then clipped to the output canvas.

      Args:
          out_shape (tuple[int]): ``(height, width)`` of the output masks.
              x coordinates are clipped to ``[0, width]`` and y coordinates
              to ``[0, height]``.
          angle (int | float): Rotation angle in degrees; it is negated
              before being passed to ``cv2.getRotationMatrix2D``.
          center (tuple[float], optional): Rotation center ``(x, y)``.
              When ``None``, the center of the current canvas,
              ``((self.width - 1) * 0.5, (self.height - 1) * 0.5)``, is used.
          scale (int | float): Isotropic scale factor.
          fill_val (int | float): Not used by polygon masks; accepted only
              for interface compatibility.

      Returns:
          PolygonMasks: The rotated masks on an ``out_shape`` canvas.
      """
      if len(self.masks) == 0:
          return PolygonMasks([], *out_shape)
      if center is None:
          # cv2.getRotationMatrix2D cannot take ``None``; use the canvas
          # center, the same default mmcv.imrotate applies to images.
          center = ((self.width - 1) * 0.5, (self.height - 1) * 0.5)
      rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)  # [2, 3]
      rotated_masks = []
      for poly_per_obj in self.masks:
          rotated_poly = []
          for p in poly_per_obj:
              # np.stack allocates a new array, so ``p`` is never modified.
              coords = np.stack([p[0::2], p[1::2]], axis=1)  # [n, 2]
              # pad 1 to convert from format [x, y] to homogeneous
              # coordinates format [x, y, 1] so the translation column of
              # the affine matrix is applied
              coords = np.concatenate(
                  (coords, np.ones((coords.shape[0], 1), coords.dtype)),
                  axis=1)  # [n, 3]
              rotated_coords = np.matmul(
                  rotate_matrix[None, :, :],
                  coords[:, :, None])[..., 0]  # [n, 2, 1] -> [n, 2]
              rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0,
                                             out_shape[1])
              rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0,
                                             out_shape[0])
              rotated_poly.append(rotated_coords.reshape(-1))
          rotated_masks.append(rotated_poly)
      return PolygonMasks(rotated_masks, *out_shape)
  def to_bitmap(self):
      """Rasterize these polygon masks into a :class:`BitmapMasks`.

      Returns:
          BitmapMasks: masks built from :meth:`to_ndarray`, with the same
          ``height`` and ``width`` as this object.
      """
      return BitmapMasks(self.to_ndarray(), self.height, self.width)
  @property
  def areas(self):
      """Compute the area of each instance with the shoelace formula.

      Adapted from `detectron2
      <https://github.com/facebookresearch/detectron2/blob/ffff8acc35ea88ad1cb1806ab0f00b4c1c5dbfd9/detectron2/structures/masks.py#L387>`_.
      The area of an instance is the sum of the (unsigned) areas of its
      polygon components, so overlapping components or holes are not
      subtracted.

      Return:
          ndarray: areas of each instance
      """  # noqa: E501
      per_instance = [
          sum(self._polygon_area(poly[0::2], poly[1::2]) for poly in polys)
          for polys in self.masks
      ]
      return np.asarray(per_instance)
  768. def _polygon_area(self, x, y):
  769. """Compute the area of a component of a polygon.
  770. Using the shoelace formula:
  771. https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
  772. Args:
  773. x (ndarray): x coordinates of the component
  774. y (ndarray): y coordinates of the component
  775. Return:
  776. float: the are of the component
  777. """ # noqa: 501
  778. return 0.5 * np.abs(
  779. np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
  def to_ndarray(self):
      """Rasterize every instance and stack the bitmaps into one ndarray.

      Returns:
          ndarray: one bitmap per instance (as produced by
          ``polygon_to_bitmap`` at ``self.height`` x ``self.width``), stacked
          along axis 0. With no instances, an empty ``uint8`` array of shape
          ``(0, self.height, self.width)``.
      """
      if len(self.masks) == 0:
          return np.empty((0, self.height, self.width), dtype=np.uint8)
      return np.stack([
          polygon_to_bitmap(polys, self.height, self.width)
          for polys in self.masks
      ])
  def to_tensor(self, dtype, device):
      """See :func:`BaseInstanceMasks.to_tensor`.

      Args:
          dtype (torch.dtype): dtype of the returned tensor.
          device (str | torch.device): device the tensor is created on.

      Returns:
          Tensor: the masks from :meth:`to_ndarray` as a tensor, or an empty
          tensor of shape ``(0, self.height, self.width)`` when there are no
          instances.
      """
      if len(self.masks) > 0:
          return torch.tensor(self.to_ndarray(), dtype=dtype, device=device)
      return torch.empty((0, self.height, self.width),
                         dtype=dtype,
                         device=device)
  @classmethod
  def random(cls,
             num_masks=3,
             height=32,
             width=32,
             n_verts=5,
             dtype=np.float32,
             rng=None):
      """Generate random polygon masks for demo / testing purposes.

      Adapted from [1]_

      Each instance consists of a single polygon with ``n_verts`` vertices.
      The polygon is sampled around a circle, rescaled into the unit square,
      and then stretched to the ``width`` x ``height`` canvas.

      Args:
          num_masks (int): Number of instances to generate.
          height (int): Canvas height; y coordinates are scaled by it.
          width (int): Canvas width; x coordinates are scaled by it.
          n_verts (int): Number of vertices of each polygon.
          dtype (type): Dtype of the returned coordinate arrays.
          rng (None | int | RandomState): Random source, normalized through
              ``ensure_rng``. NOTE(review): the accepted input types are
              inferred from usage (``uniform``/``rand`` calls and scipy's
              ``random_state``); confirm against ``ensure_rng``.

      Returns:
          An instance of ``cls`` built as ``cls(masks, height, width)``,
          where each of the ``num_masks`` instances holds one flattened
          ``[x0, y0, x1, y1, ...]`` polygon.

      References:
          .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379 # noqa: E501

      Example:
          >>> from mmdet.core.mask.structures import PolygonMasks
          >>> self = PolygonMasks.random()
          >>> print('self = {}'.format(self))
      """
      from mmdet.utils.util_random import ensure_rng
      rng = ensure_rng(rng)

      def _gen_polygon(n, irregularity, spikeyness):
          """Create a polygon by sampling points on a circle around a center.

          Random noise is added by varying the angular spacing between
          sequential points and by varying the radial distance of each point
          from the center.

          Based on original code by Mike Ounsworth

          Args:
              n (int): number of vertices
              irregularity (float): value in [0, 1] (clipped) controlling
                  the variance of the angular spacing. Each angle step is
                  drawn uniformly from ``2pi/n +/- irregularity * 2pi/n``.
              spikeyness (float): value in (0, 1] (clipped to
                  ``[1e-9, 1]``) used as the standard deviation of the
                  radius, which follows a normal distribution with mean 1
                  truncated to ``[0, 2]``.

          Returns:
              ndarray: array of shape ``(n, 2)`` with ``(x, y)`` vertices in
              order of increasing angle, placed inside the unit square.
          """
          from scipy.stats import truncnorm
          # Generate around the unit circle
          cx, cy = (0.0, 0.0)
          radius = 1

          tau = np.pi * 2

          irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n
          spikeyness = np.clip(spikeyness, 1e-9, 1)

          # generate n angle steps
          lower = (tau / n) - irregularity
          upper = (tau / n) + irregularity
          angle_steps = rng.uniform(lower, upper, n)

          # normalize the steps so they sum to exactly 2pi (point 0 and
          # point n+1 coincide); a random start angle rotates the polygon
          k = angle_steps.sum() / (2 * np.pi)
          angles = (angle_steps / k).cumsum() + rng.uniform(0, tau)

          # truncnorm expects its bounds (a, b) in units of standard
          # deviations relative to ``loc``, so convert [0, 2 * radius]
          # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
          low = 0
          high = 2 * radius
          mean = radius
          std = spikeyness
          a = (low - mean) / std
          b = (high - mean) / std
          tnorm = truncnorm(a=a, b=b, loc=mean, scale=std)

          # now generate the points
          radii = tnorm.rvs(n, random_state=rng)
          x_pts = cx + radii * np.cos(angles)
          y_pts = cy + radii * np.sin(angles)

          points = np.hstack([x_pts[:, None], y_pts[:, None]])

          # Scale each axis independently to [0, 1]
          points = points - points.min(axis=0)
          points = points / points.max(axis=0)

          # Shrink by a random factor in [0.2, 1.0), then shift by a random
          # offset that keeps every point inside the unit square
          points = points * (rng.rand() * .8 + .2)
          min_pt = points.min(axis=0)
          max_pt = points.max(axis=0)

          high = (1 - max_pt)
          low = (0 - min_pt)
          offset = (rng.rand(2) * (high - low)) + low
          points = points + offset
          return points

      def _order_vertices(verts):
          """Sort ``(n, 2)`` vertices by their angle around their centroid.

          References:
              https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise
          """
          # Despite the lat/lng names, these are the mean x and mean y.
          mlat = verts.T[0].sum() / len(verts)
          mlng = verts.T[1].sum() / len(verts)

          tau = np.pi * 2
          # Angle of each vertex around the centroid, wrapped into [0, 2pi)
          angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) +
                   tau) % tau
          sortx = angle.argsort()
          verts = verts.take(sortx, axis=0)
          return verts

      # Generate a random exterior for each requested mask
      masks = []
      for _ in range(num_masks):
          exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9))
          # Stretch unit-square coordinates to the (width, height) canvas
          exterior = (exterior * [(width, height)]).astype(dtype)
          masks.append([exterior.ravel()])

      self = cls(masks, height, width)
      return self
  def get_bboxes(self):
      """Compute the tight axis-aligned bounding box of every instance.

      The box spans the extreme vertex coordinates over all polygon
      components of an instance. Coordinates are not clipped, so boxes may
      extend outside the ``self.width`` x ``self.height`` canvas if the
      polygons do.

      Returns:
          ndarray: float32 array of shape ``(N, 4)`` in ``(x1, y1, x2, y2)``
          order, one row per instance. Instances without any vertex get an
          all-zero box.
      """
      boxes = np.zeros((len(self.masks), 4), dtype=np.float32)
      for idx, poly_per_obj in enumerate(self.masks):
          # Each polygon is a flat [x0, y0, x1, y1, ...] sequence.
          vertex_sets = [
              np.asarray(p, dtype=np.float32).reshape(-1, 2)
              for p in poly_per_obj
          ]
          vertex_sets = [xy for xy in vertex_sets if xy.size > 0]
          if not vertex_sets:
              # Nothing to bound: keep the zero box instead of an inverted
              # sentinel box (min/max over no vertices is undefined).
              continue
          xy = np.concatenate(vertex_sets, axis=0)
          # Exact extremes; no sentinel initial values, which would be wrong
          # for coordinates that are negative or beyond 2x the canvas size.
          boxes[idx, :2] = xy.min(axis=0)
          boxes[idx, 2:] = xy.max(axis=0)
      return boxes
  def polygon_to_bitmap(polygons, height, width):
      """Convert masks from the form of polygons to bitmaps.

      All polygon components of one instance are rasterized with pycocotools
      and merged into a single mask.

      Args:
          polygons (list[ndarray]): polygon components of one instance, each
              a flat ``[x0, y0, x1, y1, ...]`` array
          height (int): mask height
          width (int): mask width

      Return:
          ndarray: the merged mask in boolean bitmap representation
      """
      rles = maskUtils.frPyObjects(polygons, height, width)
      rle = maskUtils.merge(rles)
      # ``np.bool`` was only an alias of the builtin ``bool``; it was
      # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError).
      bitmap_mask = maskUtils.decode(rle).astype(bool)
      return bitmap_mask

No Description

Contributors (1)