You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

auto_augment.py 36 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import cv2
  4. import mmcv
  5. import numpy as np
  6. from ..builder import PIPELINES
  7. from .compose import Compose
  8. _MAX_LEVEL = 10
  9. def level_to_value(level, max_value):
  10. """Map from level to values based on max_value."""
  11. return (level / _MAX_LEVEL) * max_value
  12. def enhance_level_to_value(level, a=1.8, b=0.1):
  13. """Map from level to values."""
  14. return (level / _MAX_LEVEL) * a + b
  15. def random_negative(value, random_negative_prob):
  16. """Randomly negate value based on random_negative_prob."""
  17. return -value if np.random.rand() < random_negative_prob else value
  18. def bbox2fields():
  19. """The key correspondence from bboxes to labels, masks and
  20. segmentations."""
  21. bbox2label = {
  22. 'gt_bboxes': 'gt_labels',
  23. 'gt_bboxes_ignore': 'gt_labels_ignore'
  24. }
  25. bbox2mask = {
  26. 'gt_bboxes': 'gt_masks',
  27. 'gt_bboxes_ignore': 'gt_masks_ignore'
  28. }
  29. bbox2seg = {
  30. 'gt_bboxes': 'gt_semantic_seg',
  31. }
  32. return bbox2label, bbox2mask, bbox2seg
  33. @PIPELINES.register_module()
  34. class AutoAugment:
  35. """Auto augmentation.
  36. This data augmentation is proposed in `Learning Data Augmentation
  37. Strategies for Object Detection <https://arxiv.org/pdf/1906.11172>`_.
  38. TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms
  39. Args:
  40. policies (list[list[dict]]): The policies of auto augmentation. Each
  41. policy in ``policies`` is a specific augmentation policy, and is
  42. composed by several augmentations (dict). When AutoAugment is
  43. called, a random policy in ``policies`` will be selected to
  44. augment images.
  45. Examples:
  46. >>> replace = (104, 116, 124)
  47. >>> policies = [
  48. >>> [
  49. >>> dict(type='Sharpness', prob=0.0, level=8),
  50. >>> dict(
  51. >>> type='Shear',
  52. >>> prob=0.4,
  53. >>> level=0,
  54. >>> replace=replace,
  55. >>> axis='x')
  56. >>> ],
  57. >>> [
  58. >>> dict(
  59. >>> type='Rotate',
  60. >>> prob=0.6,
  61. >>> level=10,
  62. >>> replace=replace),
  63. >>> dict(type='Color', prob=1.0, level=6)
  64. >>> ]
  65. >>> ]
  66. >>> augmentation = AutoAugment(policies)
  67. >>> img = np.ones(100, 100, 3)
  68. >>> gt_bboxes = np.ones(10, 4)
  69. >>> results = dict(img=img, gt_bboxes=gt_bboxes)
  70. >>> results = augmentation(results)
  71. """
  72. def __init__(self, policies):
  73. assert isinstance(policies, list) and len(policies) > 0, \
  74. 'Policies must be a non-empty list.'
  75. for policy in policies:
  76. assert isinstance(policy, list) and len(policy) > 0, \
  77. 'Each policy in policies must be a non-empty list.'
  78. for augment in policy:
  79. assert isinstance(augment, dict) and 'type' in augment, \
  80. 'Each specific augmentation must be a dict with key' \
  81. ' "type".'
  82. self.policies = copy.deepcopy(policies)
  83. self.transforms = [Compose(policy) for policy in self.policies]
  84. def __call__(self, results):
  85. transform = np.random.choice(self.transforms)
  86. return transform(results)
  87. def __repr__(self):
  88. return f'{self.__class__.__name__}(policies={self.policies})'
  89. @PIPELINES.register_module()
  90. class Shear:
  91. """Apply Shear Transformation to image (and its corresponding bbox, mask,
  92. segmentation).
  93. Args:
  94. level (int | float): The level should be in range [0,_MAX_LEVEL].
  95. img_fill_val (int | float | tuple): The filled values for image border.
  96. If float, the same fill value will be used for all the three
  97. channels of image. If tuple, the should be 3 elements.
  98. seg_ignore_label (int): The fill value used for segmentation map.
  99. Note this value must equals ``ignore_label`` in ``semantic_head``
  100. of the corresponding config. Default 255.
  101. prob (float): The probability for performing Shear and should be in
  102. range [0, 1].
  103. direction (str): The direction for shear, either "horizontal"
  104. or "vertical".
  105. max_shear_magnitude (float): The maximum magnitude for Shear
  106. transformation.
  107. random_negative_prob (float): The probability that turns the
  108. offset negative. Should be in range [0,1]
  109. interpolation (str): Same as in :func:`mmcv.imshear`.
  110. """
  111. def __init__(self,
  112. level,
  113. img_fill_val=128,
  114. seg_ignore_label=255,
  115. prob=0.5,
  116. direction='horizontal',
  117. max_shear_magnitude=0.3,
  118. random_negative_prob=0.5,
  119. interpolation='bilinear'):
  120. assert isinstance(level, (int, float)), 'The level must be type ' \
  121. f'int or float, got {type(level)}.'
  122. assert 0 <= level <= _MAX_LEVEL, 'The level should be in range ' \
  123. f'[0,{_MAX_LEVEL}], got {level}.'
  124. if isinstance(img_fill_val, (float, int)):
  125. img_fill_val = tuple([float(img_fill_val)] * 3)
  126. elif isinstance(img_fill_val, tuple):
  127. assert len(img_fill_val) == 3, 'img_fill_val as tuple must ' \
  128. f'have 3 elements. got {len(img_fill_val)}.'
  129. img_fill_val = tuple([float(val) for val in img_fill_val])
  130. else:
  131. raise ValueError(
  132. 'img_fill_val must be float or tuple with 3 elements.')
  133. assert np.all([0 <= val <= 255 for val in img_fill_val]), 'all ' \
  134. 'elements of img_fill_val should between range [0,255].' \
  135. f'got {img_fill_val}.'
  136. assert 0 <= prob <= 1.0, 'The probability of shear should be in ' \
  137. f'range [0,1]. got {prob}.'
  138. assert direction in ('horizontal', 'vertical'), 'direction must ' \
  139. f'in be either "horizontal" or "vertical". got {direction}.'
  140. assert isinstance(max_shear_magnitude, float), 'max_shear_magnitude ' \
  141. f'should be type float. got {type(max_shear_magnitude)}.'
  142. assert 0. <= max_shear_magnitude <= 1., 'Defaultly ' \
  143. 'max_shear_magnitude should be in range [0,1]. ' \
  144. f'got {max_shear_magnitude}.'
  145. self.level = level
  146. self.magnitude = level_to_value(level, max_shear_magnitude)
  147. self.img_fill_val = img_fill_val
  148. self.seg_ignore_label = seg_ignore_label
  149. self.prob = prob
  150. self.direction = direction
  151. self.max_shear_magnitude = max_shear_magnitude
  152. self.random_negative_prob = random_negative_prob
  153. self.interpolation = interpolation
  154. def _shear_img(self,
  155. results,
  156. magnitude,
  157. direction='horizontal',
  158. interpolation='bilinear'):
  159. """Shear the image.
  160. Args:
  161. results (dict): Result dict from loading pipeline.
  162. magnitude (int | float): The magnitude used for shear.
  163. direction (str): The direction for shear, either "horizontal"
  164. or "vertical".
  165. interpolation (str): Same as in :func:`mmcv.imshear`.
  166. """
  167. for key in results.get('img_fields', ['img']):
  168. img = results[key]
  169. img_sheared = mmcv.imshear(
  170. img,
  171. magnitude,
  172. direction,
  173. border_value=self.img_fill_val,
  174. interpolation=interpolation)
  175. results[key] = img_sheared.astype(img.dtype)
  176. results['img_shape'] = results[key].shape
  177. def _shear_bboxes(self, results, magnitude):
  178. """Shear the bboxes."""
  179. h, w, c = results['img_shape']
  180. if self.direction == 'horizontal':
  181. shear_matrix = np.stack([[1, magnitude],
  182. [0, 1]]).astype(np.float32) # [2, 2]
  183. else:
  184. shear_matrix = np.stack([[1, 0], [magnitude,
  185. 1]]).astype(np.float32)
  186. for key in results.get('bbox_fields', []):
  187. min_x, min_y, max_x, max_y = np.split(
  188. results[key], results[key].shape[-1], axis=-1)
  189. coordinates = np.stack([[min_x, min_y], [max_x, min_y],
  190. [min_x, max_y],
  191. [max_x, max_y]]) # [4, 2, nb_box, 1]
  192. coordinates = coordinates[..., 0].transpose(
  193. (2, 1, 0)).astype(np.float32) # [nb_box, 2, 4]
  194. new_coords = np.matmul(shear_matrix[None, :, :],
  195. coordinates) # [nb_box, 2, 4]
  196. min_x = np.min(new_coords[:, 0, :], axis=-1)
  197. min_y = np.min(new_coords[:, 1, :], axis=-1)
  198. max_x = np.max(new_coords[:, 0, :], axis=-1)
  199. max_y = np.max(new_coords[:, 1, :], axis=-1)
  200. min_x = np.clip(min_x, a_min=0, a_max=w)
  201. min_y = np.clip(min_y, a_min=0, a_max=h)
  202. max_x = np.clip(max_x, a_min=min_x, a_max=w)
  203. max_y = np.clip(max_y, a_min=min_y, a_max=h)
  204. results[key] = np.stack([min_x, min_y, max_x, max_y],
  205. axis=-1).astype(results[key].dtype)
  206. def _shear_masks(self,
  207. results,
  208. magnitude,
  209. direction='horizontal',
  210. fill_val=0,
  211. interpolation='bilinear'):
  212. """Shear the masks."""
  213. h, w, c = results['img_shape']
  214. for key in results.get('mask_fields', []):
  215. masks = results[key]
  216. results[key] = masks.shear((h, w),
  217. magnitude,
  218. direction,
  219. border_value=fill_val,
  220. interpolation=interpolation)
  221. def _shear_seg(self,
  222. results,
  223. magnitude,
  224. direction='horizontal',
  225. fill_val=255,
  226. interpolation='bilinear'):
  227. """Shear the segmentation maps."""
  228. for key in results.get('seg_fields', []):
  229. seg = results[key]
  230. results[key] = mmcv.imshear(
  231. seg,
  232. magnitude,
  233. direction,
  234. border_value=fill_val,
  235. interpolation=interpolation).astype(seg.dtype)
  236. def _filter_invalid(self, results, min_bbox_size=0):
  237. """Filter bboxes and corresponding masks too small after shear
  238. augmentation."""
  239. bbox2label, bbox2mask, _ = bbox2fields()
  240. for key in results.get('bbox_fields', []):
  241. bbox_w = results[key][:, 2] - results[key][:, 0]
  242. bbox_h = results[key][:, 3] - results[key][:, 1]
  243. valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
  244. valid_inds = np.nonzero(valid_inds)[0]
  245. results[key] = results[key][valid_inds]
  246. # label fields. e.g. gt_labels and gt_labels_ignore
  247. label_key = bbox2label.get(key)
  248. if label_key in results:
  249. results[label_key] = results[label_key][valid_inds]
  250. # mask fields, e.g. gt_masks and gt_masks_ignore
  251. mask_key = bbox2mask.get(key)
  252. if mask_key in results:
  253. results[mask_key] = results[mask_key][valid_inds]
  254. def __call__(self, results):
  255. """Call function to shear images, bounding boxes, masks and semantic
  256. segmentation maps.
  257. Args:
  258. results (dict): Result dict from loading pipeline.
  259. Returns:
  260. dict: Sheared results.
  261. """
  262. if np.random.rand() > self.prob:
  263. return results
  264. magnitude = random_negative(self.magnitude, self.random_negative_prob)
  265. self._shear_img(results, magnitude, self.direction, self.interpolation)
  266. self._shear_bboxes(results, magnitude)
  267. # fill_val set to 0 for background of mask.
  268. self._shear_masks(
  269. results,
  270. magnitude,
  271. self.direction,
  272. fill_val=0,
  273. interpolation=self.interpolation)
  274. self._shear_seg(
  275. results,
  276. magnitude,
  277. self.direction,
  278. fill_val=self.seg_ignore_label,
  279. interpolation=self.interpolation)
  280. self._filter_invalid(results)
  281. return results
  282. def __repr__(self):
  283. repr_str = self.__class__.__name__
  284. repr_str += f'(level={self.level}, '
  285. repr_str += f'img_fill_val={self.img_fill_val}, '
  286. repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
  287. repr_str += f'prob={self.prob}, '
  288. repr_str += f'direction={self.direction}, '
  289. repr_str += f'max_shear_magnitude={self.max_shear_magnitude}, '
  290. repr_str += f'random_negative_prob={self.random_negative_prob}, '
  291. repr_str += f'interpolation={self.interpolation})'
  292. return repr_str
  293. @PIPELINES.register_module()
  294. class Rotate:
  295. """Apply Rotate Transformation to image (and its corresponding bbox, mask,
  296. segmentation).
  297. Args:
  298. level (int | float): The level should be in range (0,_MAX_LEVEL].
  299. scale (int | float): Isotropic scale factor. Same in
  300. ``mmcv.imrotate``.
  301. center (int | float | tuple[float]): Center point (w, h) of the
  302. rotation in the source image. If None, the center of the
  303. image will be used. Same in ``mmcv.imrotate``.
  304. img_fill_val (int | float | tuple): The fill value for image border.
  305. If float, the same value will be used for all the three
  306. channels of image. If tuple, the should be 3 elements (e.g.
  307. equals the number of channels for image).
  308. seg_ignore_label (int): The fill value used for segmentation map.
  309. Note this value must equals ``ignore_label`` in ``semantic_head``
  310. of the corresponding config. Default 255.
  311. prob (float): The probability for perform transformation and
  312. should be in range 0 to 1.
  313. max_rotate_angle (int | float): The maximum angles for rotate
  314. transformation.
  315. random_negative_prob (float): The probability that turns the
  316. offset negative.
  317. """
  318. def __init__(self,
  319. level,
  320. scale=1,
  321. center=None,
  322. img_fill_val=128,
  323. seg_ignore_label=255,
  324. prob=0.5,
  325. max_rotate_angle=30,
  326. random_negative_prob=0.5):
  327. assert isinstance(level, (int, float)), \
  328. f'The level must be type int or float. got {type(level)}.'
  329. assert 0 <= level <= _MAX_LEVEL, \
  330. f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.'
  331. assert isinstance(scale, (int, float)), \
  332. f'The scale must be type int or float. got type {type(scale)}.'
  333. if isinstance(center, (int, float)):
  334. center = (center, center)
  335. elif isinstance(center, tuple):
  336. assert len(center) == 2, 'center with type tuple must have '\
  337. f'2 elements. got {len(center)} elements.'
  338. else:
  339. assert center is None, 'center must be None or type int, '\
  340. f'float or tuple, got type {type(center)}.'
  341. if isinstance(img_fill_val, (float, int)):
  342. img_fill_val = tuple([float(img_fill_val)] * 3)
  343. elif isinstance(img_fill_val, tuple):
  344. assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
  345. f'have 3 elements. got {len(img_fill_val)}.'
  346. img_fill_val = tuple([float(val) for val in img_fill_val])
  347. else:
  348. raise ValueError(
  349. 'img_fill_val must be float or tuple with 3 elements.')
  350. assert np.all([0 <= val <= 255 for val in img_fill_val]), \
  351. 'all elements of img_fill_val should between range [0,255]. '\
  352. f'got {img_fill_val}.'
  353. assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
  354. 'got {prob}.'
  355. assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
  356. f'should be type int or float. got type {type(max_rotate_angle)}.'
  357. self.level = level
  358. self.scale = scale
  359. # Rotation angle in degrees. Positive values mean
  360. # clockwise rotation.
  361. self.angle = level_to_value(level, max_rotate_angle)
  362. self.center = center
  363. self.img_fill_val = img_fill_val
  364. self.seg_ignore_label = seg_ignore_label
  365. self.prob = prob
  366. self.max_rotate_angle = max_rotate_angle
  367. self.random_negative_prob = random_negative_prob
  368. def _rotate_img(self, results, angle, center=None, scale=1.0):
  369. """Rotate the image.
  370. Args:
  371. results (dict): Result dict from loading pipeline.
  372. angle (float): Rotation angle in degrees, positive values
  373. mean clockwise rotation. Same in ``mmcv.imrotate``.
  374. center (tuple[float], optional): Center point (w, h) of the
  375. rotation. Same in ``mmcv.imrotate``.
  376. scale (int | float): Isotropic scale factor. Same in
  377. ``mmcv.imrotate``.
  378. """
  379. for key in results.get('img_fields', ['img']):
  380. img = results[key].copy()
  381. img_rotated = mmcv.imrotate(
  382. img, angle, center, scale, border_value=self.img_fill_val)
  383. results[key] = img_rotated.astype(img.dtype)
  384. results['img_shape'] = results[key].shape
  385. def _rotate_bboxes(self, results, rotate_matrix):
  386. """Rotate the bboxes."""
  387. h, w, c = results['img_shape']
  388. for key in results.get('bbox_fields', []):
  389. min_x, min_y, max_x, max_y = np.split(
  390. results[key], results[key].shape[-1], axis=-1)
  391. coordinates = np.stack([[min_x, min_y], [max_x, min_y],
  392. [min_x, max_y],
  393. [max_x, max_y]]) # [4, 2, nb_bbox, 1]
  394. # pad 1 to convert from format [x, y] to homogeneous
  395. # coordinates format [x, y, 1]
  396. coordinates = np.concatenate(
  397. (coordinates,
  398. np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
  399. axis=1) # [4, 3, nb_bbox, 1]
  400. coordinates = coordinates.transpose(
  401. (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1]
  402. rotated_coords = np.matmul(rotate_matrix,
  403. coordinates) # [nb_bbox, 4, 2, 1]
  404. rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2]
  405. min_x, min_y = np.min(
  406. rotated_coords[:, :, 0], axis=1), np.min(
  407. rotated_coords[:, :, 1], axis=1)
  408. max_x, max_y = np.max(
  409. rotated_coords[:, :, 0], axis=1), np.max(
  410. rotated_coords[:, :, 1], axis=1)
  411. min_x, min_y = np.clip(
  412. min_x, a_min=0, a_max=w), np.clip(
  413. min_y, a_min=0, a_max=h)
  414. max_x, max_y = np.clip(
  415. max_x, a_min=min_x, a_max=w), np.clip(
  416. max_y, a_min=min_y, a_max=h)
  417. results[key] = np.stack([min_x, min_y, max_x, max_y],
  418. axis=-1).astype(results[key].dtype)
  419. def _rotate_masks(self,
  420. results,
  421. angle,
  422. center=None,
  423. scale=1.0,
  424. fill_val=0):
  425. """Rotate the masks."""
  426. h, w, c = results['img_shape']
  427. for key in results.get('mask_fields', []):
  428. masks = results[key]
  429. results[key] = masks.rotate((h, w), angle, center, scale, fill_val)
  430. def _rotate_seg(self,
  431. results,
  432. angle,
  433. center=None,
  434. scale=1.0,
  435. fill_val=255):
  436. """Rotate the segmentation map."""
  437. for key in results.get('seg_fields', []):
  438. seg = results[key].copy()
  439. results[key] = mmcv.imrotate(
  440. seg, angle, center, scale,
  441. border_value=fill_val).astype(seg.dtype)
  442. def _filter_invalid(self, results, min_bbox_size=0):
  443. """Filter bboxes and corresponding masks too small after rotate
  444. augmentation."""
  445. bbox2label, bbox2mask, _ = bbox2fields()
  446. for key in results.get('bbox_fields', []):
  447. bbox_w = results[key][:, 2] - results[key][:, 0]
  448. bbox_h = results[key][:, 3] - results[key][:, 1]
  449. valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
  450. valid_inds = np.nonzero(valid_inds)[0]
  451. results[key] = results[key][valid_inds]
  452. # label fields. e.g. gt_labels and gt_labels_ignore
  453. label_key = bbox2label.get(key)
  454. if label_key in results:
  455. results[label_key] = results[label_key][valid_inds]
  456. # mask fields, e.g. gt_masks and gt_masks_ignore
  457. mask_key = bbox2mask.get(key)
  458. if mask_key in results:
  459. results[mask_key] = results[mask_key][valid_inds]
  460. def __call__(self, results):
  461. """Call function to rotate images, bounding boxes, masks and semantic
  462. segmentation maps.
  463. Args:
  464. results (dict): Result dict from loading pipeline.
  465. Returns:
  466. dict: Rotated results.
  467. """
  468. if np.random.rand() > self.prob:
  469. return results
  470. h, w = results['img'].shape[:2]
  471. center = self.center
  472. if center is None:
  473. center = ((w - 1) * 0.5, (h - 1) * 0.5)
  474. angle = random_negative(self.angle, self.random_negative_prob)
  475. self._rotate_img(results, angle, center, self.scale)
  476. rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
  477. self._rotate_bboxes(results, rotate_matrix)
  478. self._rotate_masks(results, angle, center, self.scale, fill_val=0)
  479. self._rotate_seg(
  480. results, angle, center, self.scale, fill_val=self.seg_ignore_label)
  481. self._filter_invalid(results)
  482. return results
  483. def __repr__(self):
  484. repr_str = self.__class__.__name__
  485. repr_str += f'(level={self.level}, '
  486. repr_str += f'scale={self.scale}, '
  487. repr_str += f'center={self.center}, '
  488. repr_str += f'img_fill_val={self.img_fill_val}, '
  489. repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
  490. repr_str += f'prob={self.prob}, '
  491. repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
  492. repr_str += f'random_negative_prob={self.random_negative_prob})'
  493. return repr_str
  494. @PIPELINES.register_module()
  495. class Translate:
  496. """Translate the images, bboxes, masks and segmentation maps horizontally
  497. or vertically.
  498. Args:
  499. level (int | float): The level for Translate and should be in
  500. range [0,_MAX_LEVEL].
  501. prob (float): The probability for performing translation and
  502. should be in range [0, 1].
  503. img_fill_val (int | float | tuple): The filled value for image
  504. border. If float, the same fill value will be used for all
  505. the three channels of image. If tuple, the should be 3
  506. elements (e.g. equals the number of channels for image).
  507. seg_ignore_label (int): The fill value used for segmentation map.
  508. Note this value must equals ``ignore_label`` in ``semantic_head``
  509. of the corresponding config. Default 255.
  510. direction (str): The translate direction, either "horizontal"
  511. or "vertical".
  512. max_translate_offset (int | float): The maximum pixel's offset for
  513. Translate.
  514. random_negative_prob (float): The probability that turns the
  515. offset negative.
  516. min_size (int | float): The minimum pixel for filtering
  517. invalid bboxes after the translation.
  518. """
  519. def __init__(self,
  520. level,
  521. prob=0.5,
  522. img_fill_val=128,
  523. seg_ignore_label=255,
  524. direction='horizontal',
  525. max_translate_offset=250.,
  526. random_negative_prob=0.5,
  527. min_size=0):
  528. assert isinstance(level, (int, float)), \
  529. 'The level must be type int or float.'
  530. assert 0 <= level <= _MAX_LEVEL, \
  531. 'The level used for calculating Translate\'s offset should be ' \
  532. 'in range [0,_MAX_LEVEL]'
  533. assert 0 <= prob <= 1.0, \
  534. 'The probability of translation should be in range [0, 1].'
  535. if isinstance(img_fill_val, (float, int)):
  536. img_fill_val = tuple([float(img_fill_val)] * 3)
  537. elif isinstance(img_fill_val, tuple):
  538. assert len(img_fill_val) == 3, \
  539. 'img_fill_val as tuple must have 3 elements.'
  540. img_fill_val = tuple([float(val) for val in img_fill_val])
  541. else:
  542. raise ValueError('img_fill_val must be type float or tuple.')
  543. assert np.all([0 <= val <= 255 for val in img_fill_val]), \
  544. 'all elements of img_fill_val should between range [0,255].'
  545. assert direction in ('horizontal', 'vertical'), \
  546. 'direction should be "horizontal" or "vertical".'
  547. assert isinstance(max_translate_offset, (int, float)), \
  548. 'The max_translate_offset must be type int or float.'
  549. # the offset used for translation
  550. self.offset = int(level_to_value(level, max_translate_offset))
  551. self.level = level
  552. self.prob = prob
  553. self.img_fill_val = img_fill_val
  554. self.seg_ignore_label = seg_ignore_label
  555. self.direction = direction
  556. self.max_translate_offset = max_translate_offset
  557. self.random_negative_prob = random_negative_prob
  558. self.min_size = min_size
  559. def _translate_img(self, results, offset, direction='horizontal'):
  560. """Translate the image.
  561. Args:
  562. results (dict): Result dict from loading pipeline.
  563. offset (int | float): The offset for translate.
  564. direction (str): The translate direction, either "horizontal"
  565. or "vertical".
  566. """
  567. for key in results.get('img_fields', ['img']):
  568. img = results[key].copy()
  569. results[key] = mmcv.imtranslate(
  570. img, offset, direction, self.img_fill_val).astype(img.dtype)
  571. results['img_shape'] = results[key].shape
  572. def _translate_bboxes(self, results, offset):
  573. """Shift bboxes horizontally or vertically, according to offset."""
  574. h, w, c = results['img_shape']
  575. for key in results.get('bbox_fields', []):
  576. min_x, min_y, max_x, max_y = np.split(
  577. results[key], results[key].shape[-1], axis=-1)
  578. if self.direction == 'horizontal':
  579. min_x = np.maximum(0, min_x + offset)
  580. max_x = np.minimum(w, max_x + offset)
  581. elif self.direction == 'vertical':
  582. min_y = np.maximum(0, min_y + offset)
  583. max_y = np.minimum(h, max_y + offset)
  584. # the boxes translated outside of image will be filtered along with
  585. # the corresponding masks, by invoking ``_filter_invalid``.
  586. results[key] = np.concatenate([min_x, min_y, max_x, max_y],
  587. axis=-1)
  588. def _translate_masks(self,
  589. results,
  590. offset,
  591. direction='horizontal',
  592. fill_val=0):
  593. """Translate masks horizontally or vertically."""
  594. h, w, c = results['img_shape']
  595. for key in results.get('mask_fields', []):
  596. masks = results[key]
  597. results[key] = masks.translate((h, w), offset, direction, fill_val)
  598. def _translate_seg(self,
  599. results,
  600. offset,
  601. direction='horizontal',
  602. fill_val=255):
  603. """Translate segmentation maps horizontally or vertically."""
  604. for key in results.get('seg_fields', []):
  605. seg = results[key].copy()
  606. results[key] = mmcv.imtranslate(seg, offset, direction,
  607. fill_val).astype(seg.dtype)
  608. def _filter_invalid(self, results, min_size=0):
  609. """Filter bboxes and masks too small or translated out of image."""
  610. bbox2label, bbox2mask, _ = bbox2fields()
  611. for key in results.get('bbox_fields', []):
  612. bbox_w = results[key][:, 2] - results[key][:, 0]
  613. bbox_h = results[key][:, 3] - results[key][:, 1]
  614. valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
  615. valid_inds = np.nonzero(valid_inds)[0]
  616. results[key] = results[key][valid_inds]
  617. # label fields. e.g. gt_labels and gt_labels_ignore
  618. label_key = bbox2label.get(key)
  619. if label_key in results:
  620. results[label_key] = results[label_key][valid_inds]
  621. # mask fields, e.g. gt_masks and gt_masks_ignore
  622. mask_key = bbox2mask.get(key)
  623. if mask_key in results:
  624. results[mask_key] = results[mask_key][valid_inds]
  625. return results
  626. def __call__(self, results):
  627. """Call function to translate images, bounding boxes, masks and
  628. semantic segmentation maps.
  629. Args:
  630. results (dict): Result dict from loading pipeline.
  631. Returns:
  632. dict: Translated results.
  633. """
  634. if np.random.rand() > self.prob:
  635. return results
  636. offset = random_negative(self.offset, self.random_negative_prob)
  637. self._translate_img(results, offset, self.direction)
  638. self._translate_bboxes(results, offset)
  639. # fill_val defaultly 0 for BitmapMasks and None for PolygonMasks.
  640. self._translate_masks(results, offset, self.direction)
  641. # fill_val set to ``seg_ignore_label`` for the ignored value
  642. # of segmentation map.
  643. self._translate_seg(
  644. results, offset, self.direction, fill_val=self.seg_ignore_label)
  645. self._filter_invalid(results, min_size=self.min_size)
  646. return results
  647. @PIPELINES.register_module()
  648. class ColorTransform:
  649. """Apply Color transformation to image. The bboxes, masks, and
  650. segmentations are not modified.
  651. Args:
  652. level (int | float): Should be in range [0,_MAX_LEVEL].
  653. prob (float): The probability for performing Color transformation.
  654. """
  655. def __init__(self, level, prob=0.5):
  656. assert isinstance(level, (int, float)), \
  657. 'The level must be type int or float.'
  658. assert 0 <= level <= _MAX_LEVEL, \
  659. 'The level should be in range [0,_MAX_LEVEL].'
  660. assert 0 <= prob <= 1.0, \
  661. 'The probability should be in range [0,1].'
  662. self.level = level
  663. self.prob = prob
  664. self.factor = enhance_level_to_value(level)
  665. def _adjust_color_img(self, results, factor=1.0):
  666. """Apply Color transformation to image."""
  667. for key in results.get('img_fields', ['img']):
  668. # NOTE defaultly the image should be BGR format
  669. img = results[key]
  670. results[key] = mmcv.adjust_color(img, factor).astype(img.dtype)
  671. def __call__(self, results):
  672. """Call function for Color transformation.
  673. Args:
  674. results (dict): Result dict from loading pipeline.
  675. Returns:
  676. dict: Colored results.
  677. """
  678. if np.random.rand() > self.prob:
  679. return results
  680. self._adjust_color_img(results, self.factor)
  681. return results
  682. def __repr__(self):
  683. repr_str = self.__class__.__name__
  684. repr_str += f'(level={self.level}, '
  685. repr_str += f'prob={self.prob})'
  686. return repr_str
  687. @PIPELINES.register_module()
  688. class EqualizeTransform:
  689. """Apply Equalize transformation to image. The bboxes, masks and
  690. segmentations are not modified.
  691. Args:
  692. prob (float): The probability for performing Equalize transformation.
  693. """
  694. def __init__(self, prob=0.5):
  695. assert 0 <= prob <= 1.0, \
  696. 'The probability should be in range [0,1].'
  697. self.prob = prob
  698. def _imequalize(self, results):
  699. """Equalizes the histogram of one image."""
  700. for key in results.get('img_fields', ['img']):
  701. img = results[key]
  702. results[key] = mmcv.imequalize(img).astype(img.dtype)
  703. def __call__(self, results):
  704. """Call function for Equalize transformation.
  705. Args:
  706. results (dict): Results dict from loading pipeline.
  707. Returns:
  708. dict: Results after the transformation.
  709. """
  710. if np.random.rand() > self.prob:
  711. return results
  712. self._imequalize(results)
  713. return results
  714. def __repr__(self):
  715. repr_str = self.__class__.__name__
  716. repr_str += f'(prob={self.prob})'
  717. @PIPELINES.register_module()
  718. class BrightnessTransform:
  719. """Apply Brightness transformation to image. The bboxes, masks and
  720. segmentations are not modified.
  721. Args:
  722. level (int | float): Should be in range [0,_MAX_LEVEL].
  723. prob (float): The probability for performing Brightness transformation.
  724. """
  725. def __init__(self, level, prob=0.5):
  726. assert isinstance(level, (int, float)), \
  727. 'The level must be type int or float.'
  728. assert 0 <= level <= _MAX_LEVEL, \
  729. 'The level should be in range [0,_MAX_LEVEL].'
  730. assert 0 <= prob <= 1.0, \
  731. 'The probability should be in range [0,1].'
  732. self.level = level
  733. self.prob = prob
  734. self.factor = enhance_level_to_value(level)
  735. def _adjust_brightness_img(self, results, factor=1.0):
  736. """Adjust the brightness of image."""
  737. for key in results.get('img_fields', ['img']):
  738. img = results[key]
  739. results[key] = mmcv.adjust_brightness(img,
  740. factor).astype(img.dtype)
  741. def __call__(self, results):
  742. """Call function for Brightness transformation.
  743. Args:
  744. results (dict): Results dict from loading pipeline.
  745. Returns:
  746. dict: Results after the transformation.
  747. """
  748. if np.random.rand() > self.prob:
  749. return results
  750. self._adjust_brightness_img(results, self.factor)
  751. return results
  752. def __repr__(self):
  753. repr_str = self.__class__.__name__
  754. repr_str += f'(level={self.level}, '
  755. repr_str += f'prob={self.prob})'
  756. return repr_str
  757. @PIPELINES.register_module()
  758. class ContrastTransform:
  759. """Apply Contrast transformation to image. The bboxes, masks and
  760. segmentations are not modified.
  761. Args:
  762. level (int | float): Should be in range [0,_MAX_LEVEL].
  763. prob (float): The probability for performing Contrast transformation.
  764. """
  765. def __init__(self, level, prob=0.5):
  766. assert isinstance(level, (int, float)), \
  767. 'The level must be type int or float.'
  768. assert 0 <= level <= _MAX_LEVEL, \
  769. 'The level should be in range [0,_MAX_LEVEL].'
  770. assert 0 <= prob <= 1.0, \
  771. 'The probability should be in range [0,1].'
  772. self.level = level
  773. self.prob = prob
  774. self.factor = enhance_level_to_value(level)
  775. def _adjust_contrast_img(self, results, factor=1.0):
  776. """Adjust the image contrast."""
  777. for key in results.get('img_fields', ['img']):
  778. img = results[key]
  779. results[key] = mmcv.adjust_contrast(img, factor).astype(img.dtype)
  780. def __call__(self, results):
  781. """Call function for Contrast transformation.
  782. Args:
  783. results (dict): Results dict from loading pipeline.
  784. Returns:
  785. dict: Results after the transformation.
  786. """
  787. if np.random.rand() > self.prob:
  788. return results
  789. self._adjust_contrast_img(results, self.factor)
  790. return results
  791. def __repr__(self):
  792. repr_str = self.__class__.__name__
  793. repr_str += f'(level={self.level}, '
  794. repr_str += f'prob={self.prob})'
  795. return repr_str

No Description

Contributors (3)