#! /usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
from six.moves import xrange

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module

__all__ = [
    'transformer',
    'batch_transformer',
    'SpatialTransformer2dAffine',
]


def transformer(U, theta, out_size, name='SpatialTransformer2dAffine'):
    """Spatial Transformer function for 2D affine transformation, see :class:`SpatialTransformer2dAffine` class.

    Parameters
    ----------
    U : tensor
        The output of a convolutional net, it should have the shape [num_batch, height, width, num_channels].
    theta : tensor
        The output of the localisation network, it should have 6 values per sample
        (it is reshaped to [num_batch, 2, 3] internally). When produced via tanh the
        value range is [-1, 1].
    out_size : tuple of int
        The size of the output of the network (height, width).
    name : str
        Optional function name. Currently unused.

    Returns
    -------
    Tensor
        The transformed tensor of shape [num_batch, out_height, out_width, num_channels].

    References
    ----------
    - Spatial Transformer Networks (Jaderberg et al., 2015, arXiv:1506.02025)
    - The ``transformer`` example of the TensorFlow/Models repository

    Notes
    -----
    To initialize the network to the identity transform, initialise ``theta`` as follows:

    >>> import tensorflow as tf
    >>> identity = np.array([[1., 0., 0.], [0., 1., 0.]])
    >>> identity = identity.flatten()
    >>> theta = tf.Variable(initial_value=identity)

    """

    def _repeat(x, n_repeats):
        # Repeat every element of ``x`` ``n_repeats`` times, consecutively,
        # by multiplying a column vector with a row vector of ones.
        rep = tl.transpose(a=tl.expand_dims(tl.ones(shape=tl.stack([
            n_repeats,
        ])), axis=1), perm=[1, 0])
        rep = tl.cast(rep, 'int32')
        x = tl.matmul(tl.reshape(x, (-1, 1)), rep)
        return tl.reshape(x, [-1])

    def _interpolate(im, x, y, out_size):
        # Bilinearly sample ``im`` at the (flattened) normalised coordinates ``x``, ``y``.
        # constants
        num_batch, height, width, channels = tl.get_tensor_shape(im)
        x = tl.cast(x, 'float32')
        y = tl.cast(y, 'float32')
        height_f = tl.cast(height, 'float32')
        width_f = tl.cast(width, 'float32')
        out_height = out_size[0]
        out_width = out_size[1]
        zero = tl.zeros([], dtype='int32')
        max_y = tl.cast(height - 1, 'int32')
        max_x = tl.cast(width - 1, 'int32')

        # scale indices from [-1, 1] to [0, width/height]
        x = (x + 1.0) * (width_f) / 2.0
        y = (y + 1.0) * (height_f) / 2.0

        # do sampling: the four integer neighbours of every sampling point,
        # clipped so that they stay inside the image
        x0 = tl.cast(tl.floor(x), 'int32')
        x1 = x0 + 1
        y0 = tl.cast(tl.floor(y), 'int32')
        y1 = y0 + 1

        x0 = tl.clip_by_value(x0, zero, max_x)
        x1 = tl.clip_by_value(x1, zero, max_x)
        y0 = tl.clip_by_value(y0, zero, max_y)
        y1 = tl.clip_by_value(y1, zero, max_y)

        # offsets into the flattened [num_batch * height * width, channels] image
        dim2 = width
        dim1 = width * height
        base = _repeat(tl.range(num_batch) * dim1, out_height * out_width)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # use indices to lookup pixels in the flat image and restore
        # channels dim
        im_flat = tl.reshape(im, tl.stack([-1, channels]))
        im_flat = tl.cast(im_flat, 'float32')
        Ia = tl.gather(im_flat, idx_a)
        Ib = tl.gather(im_flat, idx_b)
        Ic = tl.gather(im_flat, idx_c)
        Id = tl.gather(im_flat, idx_d)

        # and finally calculate interpolated values
        x0_f = tl.cast(x0, 'float32')
        x1_f = tl.cast(x1, 'float32')
        y0_f = tl.cast(y0, 'float32')
        y1_f = tl.cast(y1, 'float32')
        wa = tl.expand_dims(((x1_f - x) * (y1_f - y)), 1)
        wb = tl.expand_dims(((x1_f - x) * (y - y0_f)), 1)
        wc = tl.expand_dims(((x - x0_f) * (y1_f - y)), 1)
        wd = tl.expand_dims(((x - x0_f) * (y - y0_f)), 1)
        output = tl.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
        return output

    def _meshgrid(height, width):
        # This should be equivalent to:
        #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
        #                         np.linspace(-1, 1, height))
        #  ones = np.ones(np.prod(x_t.shape))
        #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
        x_t = tl.matmul(
            tl.ones(shape=tl.stack([height, 1])),
            tl.transpose(a=tl.expand_dims(tl.linspace(-1.0, 1.0, width), 1), perm=[1, 0])
        )
        y_t = tl.matmul(tl.expand_dims(tl.linspace(-1.0, 1.0, height), 1), tl.ones(shape=tl.stack([1, width])))

        x_t_flat = tl.reshape(x_t, (1, -1))
        y_t_flat = tl.reshape(y_t, (1, -1))

        ones = tl.ones(shape=tl.get_tensor_shape(x_t_flat))
        grid = tl.concat(axis=0, values=[x_t_flat, y_t_flat, ones])
        return grid

    def _transform(theta, input_dim, out_size):
        num_batch, _, _, num_channels = tl.get_tensor_shape(input_dim)
        theta = tl.reshape(theta, (-1, 2, 3))
        theta = tl.cast(theta, 'float32')

        # grid of (x_t, y_t, 1), eq (1) in ref [1]
        out_height = out_size[0]
        out_width = out_size[1]
        grid = _meshgrid(out_height, out_width)
        # replicate the same target grid once per batch element
        grid = tl.expand_dims(grid, 0)
        grid = tl.reshape(grid, [-1])
        grid = tl.tile(grid, tl.stack([num_batch]))
        grid = tl.reshape(grid, tl.stack([num_batch, 3, -1]))

        # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
        T_g = tl.matmul(theta, grid)
        x_s = tl.slice(T_g, [0, 0, 0], [-1, 1, -1])
        y_s = tl.slice(T_g, [0, 1, 0], [-1, 1, -1])
        x_s_flat = tl.reshape(x_s, [-1])
        y_s_flat = tl.reshape(y_s, [-1])

        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)

        output = tl.reshape(input_transformed, tl.stack([num_batch, out_height, out_width, num_channels]))
        return output

    output = _transform(theta, U, out_size)
    return output


def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer2dAffine'):
    """Batch Spatial Transformer function for 2D affine transformation.

    Every input is repeated ``num_transforms`` times and each copy is warped by
    :func:`transformer` with its own transformation.

    Parameters
    ----------
    U : tensor
        Tensor of inputs [batch, height, width, num_channels].
    thetas : tensor
        A set of transformations for each input [batch, num_transforms, 6].
        The first two dimensions must be statically known.
    out_size : list of int
        The size of the output [out_height, out_width].
    name : str
        Optional function name. Currently unused.

    Returns
    ------
    Tensor
        Tensor of size [batch * num_transforms, out_height, out_width, num_channels].

    """
    # Use the same shape helper as ``transformer`` instead of the
    # TensorFlow-only ``get_shape().as_list()``.
    num_batch, num_transforms = map(int, tl.get_tensor_shape(thetas)[:2])
    # [[0, 0, ...], [1, 1, ...], ...]: input i is gathered num_transforms times
    indices = [[i] * num_transforms for i in xrange(num_batch)]
    input_repeated = tl.gather(U, tl.reshape(indices, [-1]))
    return transformer(input_repeated, thetas, out_size)


class SpatialTransformer2dAffine(Module):
    """The :class:`SpatialTransformer2dAffine` class is a 2D Spatial Transformer Layer for 2D affine transformation.

    The layer maps ``theta_input`` to the 6 affine parameters with a dense
    projection followed by tanh, then warps ``U`` with :func:`transformer`.
    Weights start at zero and the bias at the identity transform, so the layer
    initially applies tanh of the identity parameters.

    Parameters
    -----------
    out_size : tuple of int
        The size of the output of the network (height, width), the feature maps will be resized by this.
    in_channels : int or None
        The number of in channels of ``theta_input``. If None, it is taken from the
        input shapes the first time :meth:`build` is called.
    data_format : str
        "channel_last" (NHWC, default); "channels_last" is accepted as an alias.
        Any other value (e.g. "channels_first") makes :meth:`forward` raise.
    name : str
        A unique layer name.

    References
    -----------
    - Spatial Transformer Networks (Jaderberg et al., 2015, arXiv:1506.02025)
    - The ``transformer`` example of the TensorFlow/Models repository

    """

    def __init__(
        self,
        out_size=(40, 40),
        in_channels=None,
        data_format='channel_last',
        name=None,
    ):
        super(SpatialTransformer2dAffine, self).__init__(name)

        self.in_channels = in_channels
        self.out_size = out_size
        self.data_format = data_format

        # With a known channel count the weights can be created right away.
        if self.in_channels is not None:
            self.build(self.in_channels)
            self._built = True

        logging.info("SpatialTransformer2dAffine %s" % self.name)

    def __repr__(self):
        # Join only the fields that are present, so a missing ``in_channels`` or
        # ``name`` no longer leaves a dangling ", " inside the parentheses.
        fields = ['out_size={out_size}']
        if self.in_channels is not None:
            fields.append('in_channels=\'{in_channels}\'')
        if self.name is not None:
            fields.append('name=\'{name}\'')
        s = '{classname}(' + ', '.join(fields) + ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape):
        """Create the projection weights and the identity-initialised bias.

        Parameters
        ----------
        inputs_shape : list
            Read only when ``in_channels`` was not given; it is then expected to
            hold two entries, with ``inputs_shape[0][-1]`` used as the channel count.

        Raises
        ------
        AssertionError
            If ``in_channels`` is unknown and ``inputs_shape`` does not have exactly two entries.

        """
        # NOTE(review): this compares the number of entries in ``inputs_shape``,
        # presumably (theta_input shape, U shape), not the rank of theta_input
        # that the message mentions; confirm the intended check.
        if self.in_channels is None and len(inputs_shape) != 2:
            raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it")

        if not self.in_channels:
            # infer the channel count from the last dimension of theta_input's shape
            self.in_channels = inputs_shape[0][-1]
        shape = (self.in_channels, 6)

        self.W = self._get_weights("weights", shape=shape, init=tl.initializers.Zeros())
        # Positional shape argument: the ``newshape`` keyword of ``np.reshape``
        # is deprecated in recent NumPy releases.
        identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), (6, ))
        self.b = self._get_weights("biases", shape=(6, ), init=tl.initializers.Constant(identity))

    def forward(self, inputs):
        """
        :param inputs: a tuple (theta_input, U).
            - theta_input is of size [batch, in_channels]. A dense projection makes the theta size
              [batch, 6], with value range [-1, 1] (via tanh).
            - U is the previous layer, which the affine transformation is applied to.
        :return: tensor of size [batch, out_size[0], out_size[1], n_channels] after affine transformation,
            n_channels is identical to that of U.
        :raises Exception: if ``data_format`` is not a supported channels-last format.
        """
        # Fail before doing any computation when the layout is unsupported.
        if self.data_format not in ('channel_last', 'channels_last'):
            raise Exception("unimplement data_format {}".format(self.data_format))

        theta_input, U = inputs
        theta = tl.tanh(tl.matmul(theta_input, self.W) + self.b)
        outputs = transformer(U, theta, out_size=self.out_size)

        # automatically set batch_size and channels
        # e.g. [?, 40, 40, ?] --> [64, 40, 40, 1] or [64, 20, 20, 4]
        batch_size = theta_input.shape[0]
        n_channels = U.shape[-1]
        outputs = tl.reshape(outputs, shape=[batch_size, self.out_size[0], self.out_size[1], n_channels])
        return outputs