# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# File: transformer.py

import inspect
import numpy as np
import pprint
import sys
from abc import ABCMeta, abstractmethod
from fvcore.transforms.transform import (
    BlendTransform,
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    Transform,
    TransformList,
)
from PIL import Image

from .transform import ExtentTransform, ResizeTransform

__all__ = [
    "RandomBrightness",
    "RandomContrast",
    "RandomCrop",
    "RandomExtent",
    "RandomFlip",
    "RandomSaturation",
    "RandomLighting",
    "Resize",
    "ResizeShortestEdge",
    "TransformGen",
    "apply_transform_gens",
]


def check_dtype(img):
    """
    Validate that `img` is an image array the transform generators accept.

    Args:
        img (ndarray): a 2D or 3D array whose dtype is uint8 or non-integer
            (e.g. floating point).

    Raises:
        AssertionError: if `img` is not an ndarray, has an integer dtype other
            than uint8, or is not 2- or 3-dimensional.
    """
    assert isinstance(img, np.ndarray), "[TransformGen] Needs an numpy array, but got a {}!".format(
        type(img)
    )
    # `img.dtype` is an `np.dtype` instance, so it has to be tested with
    # `np.issubdtype`; an `isinstance(..., np.integer)` test is always False.
    assert not np.issubdtype(img.dtype, np.integer) or (
        img.dtype == np.uint8
    ), "[TransformGen] Got image of type {}, use uint8 or floating points instead!".format(
        img.dtype
    )
    assert img.ndim in [2, 3], img.ndim


class TransformGen(metaclass=ABCMeta):
    """
    TransformGen takes an image of type uint8 in range [0, 255], or
    floating point in range [0, 1] or [0, 255] as input.

    It creates a :class:`Transform` based on the given image, sometimes with randomness.
    The transform can then be used to transform images
    or other data (boxes, points, annotations, etc.) associated with it.

    The assumption made in this class
    is that the image itself is sufficient to instantiate a transform.
    When this assumption is not true, you need to create the transforms by your own.

    A list of `TransformGen` can be applied with :func:`apply_transform_gens`.
    """

    def _init(self, params=None):
        """
        Store constructor arguments as attributes of the same name.

        Args:
            params (dict or None): typically ``locals()`` of the calling
                ``__init__``. The key "self" and any key starting with an
                underscore (e.g. ``__class__``) are skipped.
        """
        if params:
            for k, v in params.items():
                if k != "self" and not k.startswith("_"):
                    setattr(self, k, v)

    @abstractmethod
    def get_transform(self, img):
        """
        Create the transform to apply to `img`.

        Args:
            img (ndarray): the input image.

        Returns:
            Transform: the transform generated for this image.
        """
        pass

    def _rand_range(self, low=1.0, high=None, size=None):
        """
        Uniform float random number between low and high.

        If `high` is None, the sample is drawn between 0 and `low` instead.
        If `size` is None, a single scalar sample is drawn.
        """
        if high is None:
            low, high = 0, low
        if size is None:
            size = []
        return np.random.uniform(low, high, size)

    def __repr__(self):
        """
        Produce something like:
        "MyTransformGen(field1={self.field1}, field2={self.field2})"

        Arguments whose stored attribute is the very same object as the
        constructor default are omitted. If the constructor takes *args or
        **kwargs, or an argument has no attribute of the same name, this falls
        back to the default object representation.
        """
        try:
            sig = inspect.signature(self.__init__)
            classname = type(self).__name__
            argstr = []
            for name, param in sig.parameters.items():
                assert (
                    param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
                ), "The default __repr__ doesn't support *args or **kwargs"
                assert hasattr(self, name), (
                    "Attribute {} not found! "
                    "Default __repr__ only works if attributes match the constructor.".format(name)
                )
                attr = getattr(self, name)
                default = param.default
                if default is attr:
                    continue
                argstr.append("{}={}".format(name, pprint.pformat(attr)))
            return "{}({})".format(classname, ", ".join(argstr))
        except AssertionError:
            return super().__repr__()

    __str__ = __repr__


class RandomFlip(TransformGen):
    """
    Flip the image horizontally with the given probability.

    TODO Vertical flip to be implemented.
    """

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): probability of flip.
        """
        horiz, vert = True, False
        # TODO implement vertical flip when we need it
        super().__init__()

        if horiz and vert:
            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
        if not horiz and not vert:
            raise ValueError("At least one of horiz or vert has to be True!")
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; only its width is used.

        Returns:
            Transform: an `HFlipTransform` when the random draw is below
            `prob`, otherwise a `NoOpTransform`.
        """
        _, w = img.shape[:2]
        do = self._rand_range() < self.prob
        if do:
            return HFlipTransform(w)
        else:
            return NoOpTransform()


class Resize(TransformGen):
    """ Resize image to a target size"""

    def __init__(self, shape, interp=Image.BILINEAR):
        """
        Args:
            shape: (h, w) tuple or a int
            interp: PIL interpolation method
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape, shape)
        shape = tuple(shape)
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; only its height and width are used.

        Returns:
            ResizeTransform: resize from the image's (h, w) to `self.shape`.
        """
        return ResizeTransform(
            img.shape[0], img.shape[1], self.shape[0], self.shape[1], self.interp
        )


class ResizeShortestEdge(TransformGen):
    """
    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    def __init__(
        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
    ):
        """
        Args:
            short_edge_length (list[int]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the shortest edge length.
                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
                A single int is treated as the pair (value, value).
            max_size (int): maximum allowed longest edge length.
            sample_style (str): either "range" or "choice".
            interp: PIL interpolation method.
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style

        self.is_range = sample_style == "range"
        if isinstance(short_edge_length, int):
            short_edge_length = (short_edge_length, short_edge_length)
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; only its height and width are used.

        Returns:
            Transform: a `ResizeTransform` to the sampled size, or a
            `NoOpTransform` when the sampled short edge length is 0.
        """
        h, w = img.shape[:2]

        if self.is_range:
            # The upper bound of np.random.randint is exclusive, hence the + 1.
            size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
        else:
            size = np.random.choice(self.short_edge_length)
        if size == 0:
            return NoOpTransform()

        scale = size * 1.0 / min(h, w)
        if h < w:
            newh, neww = size, scale * w
        else:
            newh, neww = scale * h, size
        if max(newh, neww) > self.max_size:
            # Shrink both edges by the same factor so the longer one fits max_size.
            scale = self.max_size * 1.0 / max(newh, neww)
            newh = newh * scale
            neww = neww * scale
        # Round to the nearest integer pixel count.
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return ResizeTransform(h, w, newh, neww, self.interp)


class RandomCrop(TransformGen):
    """
    Randomly crop a subimage out of an image.
    """

    def __init__(self, crop_type: str, crop_size):
        """
        Args:
            crop_type (str): one of "relative_range", "relative", "absolute".
                See `config/defaults.py` for explanation.
            crop_size (tuple[float]): the relative ratio or absolute pixels of
                height and width
        """
        super().__init__()
        assert crop_type in ["relative_range", "relative", "absolute"]
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; only its height and width are used.

        Returns:
            CropTransform: a crop of the computed size at a random offset that
            keeps the crop window inside the image.
        """
        h, w = img.shape[:2]
        croph, cropw = self.get_crop_size((h, w))
        assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
        # Offsets are sampled from [0, h - croph] and [0, w - cropw].
        h0 = np.random.randint(h - croph + 1)
        w0 = np.random.randint(w - cropw + 1)
        return CropTransform(w0, h0, cropw, croph)

    def get_crop_size(self, image_size):
        """
        Args:
            image_size (tuple): height, width

        Returns:
            crop_size (tuple): height, width in absolute pixels

        Raises:
            NotImplementedError: if `self.crop_type` is not a supported type.
        """
        h, w = image_size
        if self.crop_type == "relative":
            ch, cw = self.crop_size
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "relative_range":
            # Sample each ratio uniformly between the configured ratio and 1.
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "absolute":
            return self.crop_size
        else:
            raise NotImplementedError("Unknown crop type {}".format(self.crop_type))


class RandomExtent(TransformGen):
    """
    Outputs an image by cropping a random "subrect" of the source image.

    The subrect can be parameterized to include pixels outside the source image,
    in which case they will be set to zeros (i.e. black). The size of the output
    image will vary with the size of the random subrect.
    """

    def __init__(self, scale_range, shift_range):
        """
        Args:
            scale_range (l, h): Range of input-to-output size scaling factor
            shift_range (x, y): Range of shifts of the cropped subrect. The rect
                is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
                where (w, h) is the (width, height) of the input image. Set each
                component to zero to crop at the image's center.
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; only its height and width are used.

        Returns:
            ExtentTransform: extracts the randomly scaled and shifted subrect;
            the output size is the (height, width) of that subrect truncated
            to integers.
        """
        img_h, img_w = img.shape[:2]

        # Initialize src_rect to fit the input image.
        src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])

        # Apply a random scaling to the src_rect.
        src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])

        # Apply a random shift to the coordinates origin.
        src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
        src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)

        # Map src_rect coordinates into image coordinates (center at corner).
        src_rect[0::2] += 0.5 * img_w
        src_rect[1::2] += 0.5 * img_h

        return ExtentTransform(
            src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
            output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
        )


class RandomContrast(TransformGen):
    """
    Randomly transforms image contrast.

    Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce contrast
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase contrast

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; its mean value is the blend source.

        Returns:
            BlendTransform: blends the image (weight w) with its mean value
            (weight 1 - w), where w is the sampled intensity.
        """
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(src_image=img.mean(), src_weight=1 - w, dst_weight=w)


class RandomBrightness(TransformGen):
    """
    Randomly transforms image brightness.

    Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce brightness
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase brightness

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image (not inspected).

        Returns:
            BlendTransform: blends the image (weight w) with 0 (weight 1 - w),
            where w is the sampled intensity.
        """
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)


class RandomSaturation(TransformGen):
    """
    Randomly transforms image saturation.

    Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce saturation (make the image more grayscale)
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase saturation

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation (1 preserves input).
            intensity_max (float): Maximum augmentation (1 preserves input).
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; its last dimension must be 3.

        Returns:
            BlendTransform: blends the image (weight w) with a weighted
            per-pixel sum of its three channels (weight 1 - w), where w is the
            sampled intensity.
        """
        assert img.shape[-1] == 3, "Saturation only works on RGB images"
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        # Weighted channel sum, kept 3D so it broadcasts against the image.
        grayscale = img.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
        return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)


class RandomLighting(TransformGen):
    """
    Randomly transforms image color using fixed PCA over ImageNet.

    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale):
        """
        Args:
            scale (float): Standard deviation of principal component weighting.
        """
        super().__init__()
        self._init(locals())
        self.eigen_vecs = np.array(
            [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
        )
        self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])

    def get_transform(self, img):
        """
        Args:
            img (ndarray): the input image; its last dimension must be 3.

        Returns:
            BlendTransform: adds a per-channel offset built from the fixed
            eigenvectors, scaled by the eigenvalues and by weights drawn from
            a normal distribution with standard deviation `scale`.
        """
        assert img.shape[-1] == 3, "RandomLighting only works on RGB images"
        weights = np.random.normal(scale=self.scale, size=3)
        return BlendTransform(
            src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
        )


def apply_transform_gens(transform_gens, img):
    """
    Apply a list of :class:`TransformGen` on the input image, and
    returns the transformed image and a list of transforms.

    We cannot simply create and return all transforms without
    applying it to the image, because a subsequent transform may
    need the output of the previous one.

    Args:
        transform_gens (list): list of :class:`TransformGen` instance to
            be applied.
        img (ndarray): uint8 or floating point images with 1 or 3 channels.

    Returns:
        ndarray: the transformed image
        TransformList: contains the transforms that were used.
    """
    for g in transform_gens:
        assert isinstance(g, TransformGen), g

    check_dtype(img)

    tfms = []
    for g in transform_gens:
        tfm = g.get_transform(img)
        assert isinstance(
            tfm, Transform
        ), "TransformGen {} must return an instance of Transform! Got {} instead".format(g, tfm)
        # Each generator sees the output of the previous transform.
        img = tfm.apply_image(img)
        tfms.append(tfm)
    return img, TransformList(tfms)