You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

transform_gen.py 15 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. # File: transformer.py
  4. import inspect
  5. import numpy as np
  6. import pprint
  7. import sys
  8. from abc import ABCMeta, abstractmethod
  9. from fvcore.transforms.transform import (
  10. BlendTransform,
  11. CropTransform,
  12. HFlipTransform,
  13. NoOpTransform,
  14. Transform,
  15. TransformList,
  16. )
  17. from PIL import Image
  18. from .transform import ExtentTransform, ResizeTransform
  19. __all__ = [
  20. "RandomBrightness",
  21. "RandomContrast",
  22. "RandomCrop",
  23. "RandomExtent",
  24. "RandomFlip",
  25. "RandomSaturation",
  26. "RandomLighting",
  27. "Resize",
  28. "ResizeShortestEdge",
  29. "TransformGen",
  30. "apply_transform_gens",
  31. ]
  32. def check_dtype(img):
  33. assert isinstance(img, np.ndarray), "[TransformGen] Needs an numpy array, but got a {}!".format(
  34. type(img)
  35. )
  36. assert not isinstance(img.dtype, np.integer) or (
  37. img.dtype == np.uint8
  38. ), "[TransformGen] Got image of type {}, use uint8 or floating points instead!".format(
  39. img.dtype
  40. )
  41. assert img.ndim in [2, 3], img.ndim
  42. class TransformGen(metaclass=ABCMeta):
  43. """
  44. TransformGen takes an image of type uint8 in range [0, 255], or
  45. floating point in range [0, 1] or [0, 255] as input.
  46. It creates a :class:`Transform` based on the given image, sometimes with randomness.
  47. The transform can then be used to transform images
  48. or other data (boxes, points, annotations, etc.) associated with it.
  49. The assumption made in this class
  50. is that the image itself is sufficient to instantiate a transform.
  51. When this assumption is not true, you need to create the transforms by your own.
  52. A list of `TransformGen` can be applied with :func:`apply_transform_gens`.
  53. """
  54. def _init(self, params=None):
  55. if params:
  56. for k, v in params.items():
  57. if k != "self" and not k.startswith("_"):
  58. setattr(self, k, v)
  59. @abstractmethod
  60. def get_transform(self, img):
  61. pass
  62. def _rand_range(self, low=1.0, high=None, size=None):
  63. """
  64. Uniform float random number between low and high.
  65. """
  66. if high is None:
  67. low, high = 0, low
  68. if size is None:
  69. size = []
  70. return np.random.uniform(low, high, size)
  71. def __repr__(self):
  72. """
  73. Produce something like:
  74. "MyTransformGen(field1={self.field1}, field2={self.field2})"
  75. """
  76. try:
  77. sig = inspect.signature(self.__init__)
  78. classname = type(self).__name__
  79. argstr = []
  80. for name, param in sig.parameters.items():
  81. assert (
  82. param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
  83. ), "The default __repr__ doesn't support *args or **kwargs"
  84. assert hasattr(self, name), (
  85. "Attribute {} not found! "
  86. "Default __repr__ only works if attributes match the constructor.".format(name)
  87. )
  88. attr = getattr(self, name)
  89. default = param.default
  90. if default is attr:
  91. continue
  92. argstr.append("{}={}".format(name, pprint.pformat(attr)))
  93. return "{}({})".format(classname, ", ".join(argstr))
  94. except AssertionError:
  95. return super().__repr__()
  96. __str__ = __repr__
  97. class RandomFlip(TransformGen):
  98. """
  99. Flip the image horizontally with the given probability.
  100. TODO Vertical flip to be implemented.
  101. """
  102. def __init__(self, prob=0.5):
  103. """
  104. Args:
  105. prob (float): probability of flip.
  106. """
  107. horiz, vert = True, False
  108. # TODO implement vertical flip when we need it
  109. super().__init__()
  110. if horiz and vert:
  111. raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
  112. if not horiz and not vert:
  113. raise ValueError("At least one of horiz or vert has to be True!")
  114. self._init(locals())
  115. def get_transform(self, img):
  116. _, w = img.shape[:2]
  117. do = self._rand_range() < self.prob
  118. if do:
  119. return HFlipTransform(w)
  120. else:
  121. return NoOpTransform()
  122. class Resize(TransformGen):
  123. """ Resize image to a target size"""
  124. def __init__(self, shape, interp=Image.BILINEAR):
  125. """
  126. Args:
  127. shape: (h, w) tuple or a int
  128. interp: PIL interpolation method
  129. """
  130. if isinstance(shape, int):
  131. shape = (shape, shape)
  132. shape = tuple(shape)
  133. self._init(locals())
  134. def get_transform(self, img):
  135. return ResizeTransform(
  136. img.shape[0], img.shape[1], self.shape[0], self.shape[1], self.interp
  137. )
  138. class ResizeShortestEdge(TransformGen):
  139. """
  140. Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
  141. If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
  142. """
  143. def __init__(
  144. self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
  145. ):
  146. """
  147. Args:
  148. short_edge_length (list[int]): If ``sample_style=="range"``,
  149. a [min, max] interval from which to sample the shortest edge length.
  150. If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
  151. max_size (int): maximum allowed longest edge length.
  152. sample_style (str): either "range" or "choice".
  153. """
  154. super().__init__()
  155. assert sample_style in ["range", "choice"], sample_style
  156. self.is_range = sample_style == "range"
  157. if isinstance(short_edge_length, int):
  158. short_edge_length = (short_edge_length, short_edge_length)
  159. self._init(locals())
  160. def get_transform(self, img):
  161. h, w = img.shape[:2]
  162. if self.is_range:
  163. size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
  164. else:
  165. size = np.random.choice(self.short_edge_length)
  166. if size == 0:
  167. return NoOpTransform()
  168. scale = size * 1.0 / min(h, w)
  169. if h < w:
  170. newh, neww = size, scale * w
  171. else:
  172. newh, neww = scale * h, size
  173. if max(newh, neww) > self.max_size:
  174. scale = self.max_size * 1.0 / max(newh, neww)
  175. newh = newh * scale
  176. neww = neww * scale
  177. neww = int(neww + 0.5)
  178. newh = int(newh + 0.5)
  179. return ResizeTransform(h, w, newh, neww, self.interp)
  180. class RandomCrop(TransformGen):
  181. """
  182. Randomly crop a subimage out of an image.
  183. """
  184. def __init__(self, crop_type: str, crop_size):
  185. """
  186. Args:
  187. crop_type (str): one of "relative_range", "relative", "absolute".
  188. See `config/defaults.py` for explanation.
  189. crop_size (tuple[float]): the relative ratio or absolute pixels of
  190. height and width
  191. """
  192. super().__init__()
  193. assert crop_type in ["relative_range", "relative", "absolute"]
  194. self._init(locals())
  195. def get_transform(self, img):
  196. h, w = img.shape[:2]
  197. croph, cropw = self.get_crop_size((h, w))
  198. assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
  199. h0 = np.random.randint(h - croph + 1)
  200. w0 = np.random.randint(w - cropw + 1)
  201. return CropTransform(w0, h0, cropw, croph)
  202. def get_crop_size(self, image_size):
  203. """
  204. Args:
  205. image_size (tuple): height, width
  206. Returns:
  207. crop_size (tuple): height, width in absolute pixels
  208. """
  209. h, w = image_size
  210. if self.crop_type == "relative":
  211. ch, cw = self.crop_size
  212. return int(h * ch + 0.5), int(w * cw + 0.5)
  213. elif self.crop_type == "relative_range":
  214. crop_size = np.asarray(self.crop_size, dtype=np.float32)
  215. ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
  216. return int(h * ch + 0.5), int(w * cw + 0.5)
  217. elif self.crop_type == "absolute":
  218. return self.crop_size
  219. else:
  220. NotImplementedError("Unknown crop type {}".format(self.crop_type))
  221. class RandomExtent(TransformGen):
  222. """
  223. Outputs an image by cropping a random "subrect" of the source image.
  224. The subrect can be parameterized to include pixels outside the source image,
  225. in which case they will be set to zeros (i.e. black). The size of the output
  226. image will vary with the size of the random subrect.
  227. """
  228. def __init__(self, scale_range, shift_range):
  229. """
  230. Args:
  231. output_size (h, w): Dimensions of output image
  232. scale_range (l, h): Range of input-to-output size scaling factor
  233. shift_range (x, y): Range of shifts of the cropped subrect. The rect
  234. is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
  235. where (w, h) is the (width, height) of the input image. Set each
  236. component to zero to crop at the image's center.
  237. """
  238. super().__init__()
  239. self._init(locals())
  240. def get_transform(self, img):
  241. img_h, img_w = img.shape[:2]
  242. # Initialize src_rect to fit the input image.
  243. src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
  244. # Apply a random scaling to the src_rect.
  245. src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
  246. # Apply a random shift to the coordinates origin.
  247. src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
  248. src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
  249. # Map src_rect coordinates into image coordinates (center at corner).
  250. src_rect[0::2] += 0.5 * img_w
  251. src_rect[1::2] += 0.5 * img_h
  252. return ExtentTransform(
  253. src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
  254. output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
  255. )
  256. class RandomContrast(TransformGen):
  257. """
  258. Randomly transforms image contrast.
  259. Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
  260. - intensity < 1 will reduce contrast
  261. - intensity = 1 will preserve the input image
  262. - intensity > 1 will increase contrast
  263. See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
  264. """
  265. def __init__(self, intensity_min, intensity_max):
  266. """
  267. Args:
  268. intensity_min (float): Minimum augmentation
  269. intensity_max (float): Maximum augmentation
  270. """
  271. super().__init__()
  272. self._init(locals())
  273. def get_transform(self, img):
  274. w = np.random.uniform(self.intensity_min, self.intensity_max)
  275. return BlendTransform(src_image=img.mean(), src_weight=1 - w, dst_weight=w)
  276. class RandomBrightness(TransformGen):
  277. """
  278. Randomly transforms image brightness.
  279. Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
  280. - intensity < 1 will reduce brightness
  281. - intensity = 1 will preserve the input image
  282. - intensity > 1 will increase brightness
  283. See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
  284. """
  285. def __init__(self, intensity_min, intensity_max):
  286. """
  287. Args:
  288. intensity_min (float): Minimum augmentation
  289. intensity_max (float): Maximum augmentation
  290. """
  291. super().__init__()
  292. self._init(locals())
  293. def get_transform(self, img):
  294. w = np.random.uniform(self.intensity_min, self.intensity_max)
  295. return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
  296. class RandomSaturation(TransformGen):
  297. """
  298. Randomly transforms image saturation.
  299. Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
  300. - intensity < 1 will reduce saturation (make the image more grayscale)
  301. - intensity = 1 will preserve the input image
  302. - intensity > 1 will increase saturation
  303. See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
  304. """
  305. def __init__(self, intensity_min, intensity_max):
  306. """
  307. Args:
  308. intensity_min (float): Minimum augmentation (1 preserves input).
  309. intensity_max (float): Maximum augmentation (1 preserves input).
  310. """
  311. super().__init__()
  312. self._init(locals())
  313. def get_transform(self, img):
  314. assert img.shape[-1] == 3, "Saturation only works on RGB images"
  315. w = np.random.uniform(self.intensity_min, self.intensity_max)
  316. grayscale = img.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
  317. return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
  318. class RandomLighting(TransformGen):
  319. """
  320. Randomly transforms image color using fixed PCA over ImageNet.
  321. The degree of color jittering is randomly sampled via a normal distribution,
  322. with standard deviation given by the scale parameter.
  323. """
  324. def __init__(self, scale):
  325. """
  326. Args:
  327. scale (float): Standard deviation of principal component weighting.
  328. """
  329. super().__init__()
  330. self._init(locals())
  331. self.eigen_vecs = np.array(
  332. [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
  333. )
  334. self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])
  335. def get_transform(self, img):
  336. assert img.shape[-1] == 3, "Saturation only works on RGB images"
  337. weights = np.random.normal(scale=self.scale, size=3)
  338. return BlendTransform(
  339. src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
  340. )
  341. def apply_transform_gens(transform_gens, img):
  342. """
  343. Apply a list of :class:`TransformGen` on the input image, and
  344. returns the transformed image and a list of transforms.
  345. We cannot simply create and return all transforms without
  346. applying it to the image, because a subsequent transform may
  347. need the output of the previous one.
  348. Args:
  349. transform_gens (list): list of :class:`TransformGen` instance to
  350. be applied.
  351. img (ndarray): uint8 or floating point images with 1 or 3 channels.
  352. Returns:
  353. ndarray: the transformed image
  354. TransformList: contain the transforms that's used.
  355. """
  356. for g in transform_gens:
  357. assert isinstance(g, TransformGen), g
  358. check_dtype(img)
  359. tfms = []
  360. for g in transform_gens:
  361. tfm = g.get_transform(img)
  362. assert isinstance(
  363. tfm, Transform
  364. ), "TransformGen {} must return an instance of Transform! Got {} instead".format(g, tfm)
  365. img = tfm.apply_image(img)
  366. tfms.append(tfm)
  367. return img, TransformList(tfms)

No Description