|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """STN module"""
- import numpy as np
- import mindspore
- from mindspore import Tensor
- from mindspore.ops import operations as P
- from mindspore.ops import composite as C
- import mindspore.nn as nn
-
class STN(nn.Cell):
    """
    Spatial Transformer Network (STN) layer.

    Warps a feature map with a per-sample 2x3 affine transform: a base
    sampling grid built at construction time is transformed by ``theta``
    and the input is then resampled with bilinear interpolation.

    Input
    -----
    - H, W: sizes used to build the base sampling grid. The grid holds
      H * W points, so the (height, width) requested later in
      ``affine_grid_generator`` must multiply to the same count.
    """

    def __init__(self, H, W):
        super(STN, self).__init__()
        # NOTE(review): x is sampled with H points and y with W points, so the
        # meshgrid below has shape (W, H), while affine_grid_generator reshapes
        # the transformed grid to (..., height, width). The two layouts agree
        # only when H == W - confirm against callers before using a
        # non-square grid.
        x = np.linspace(-1.0, 1.0, H)
        y = np.linspace(-1.0, 1.0, W)
        x_t, y_t = np.meshgrid(x, y)
        x_t = Tensor(x_t, mindspore.float32)
        y_t = Tensor(y_t, mindspore.float32)

        # Flatten each coordinate plane into a single row of H * W values.
        expand_dims = P.ExpandDims()
        flatten = P.Flatten()
        x_t_flat = flatten(expand_dims(x_t, 0))
        y_t_flat = flatten(expand_dims(y_t, 0))

        # Homogeneous coordinates: rows are (x, y, 1), shape (3, H * W),
        # stored with a leading axis of 1 so it can be tiled per batch.
        ones = P.OnesLike()(x_t_flat)
        sampling_grid = P.Concat()((x_t_flat, y_t_flat, ones))
        self.sampling_grid = expand_dims(sampling_grid, 0)

        # Pre-built batch indices, sliced down to the actual batch size in
        # get_pixel_value. Only 128 entries exist, so larger batches are
        # presumably unsupported - verify before raising the batch size.
        max_batch_size = 128
        batch_idx = np.arange(max_batch_size).reshape((max_batch_size, 1, 1))
        self.batch_idx = Tensor(batch_idx, mindspore.float32)
        self.zero = Tensor(np.zeros([]), mindspore.float32)

    def get_pixel_value(self, img, x, y):
        """
        Utility function to get pixel values at coordinates x and y
        from a 4D tensor image.

        The outermost rows and columns of ``img`` are overwritten with zero
        before gathering, so any lookup that lands on the border reads zero.

        Input
        -----
        - img: tensor of shape (B, H, W, C)
        - x: tensor of column coordinates, shape (B, H_out, W_out)
        - y: tensor of row coordinates, shape (B, H_out, W_out)

        Returns
        -------
        - output: float32 tensor of shape (B, H_out, W_out, C)
        """
        shape = P.Shape()
        grid_shape = shape(x)
        batch_size = grid_shape[0]
        grid_height = grid_shape[1]
        grid_width = grid_shape[2]

        # The border indices must come from the image itself; the coordinate
        # grid may have a different spatial size than the image.
        img_shape = shape(img)
        img_height = img_shape[1]
        img_width = img_shape[2]
        img[:, 0, :, :] = self.zero
        img[:, img_height - 1, :, :] = self.zero
        img[:, :, 0, :] = self.zero
        img[:, :, img_width - 1, :] = self.zero

        # One batch index per sampled location, shaped like the grid.
        batch_idx = P.Slice()(self.batch_idx, (0, 0, 0), (batch_size, 1, 1))
        b = P.Tile()(batch_idx, (1, grid_height, grid_width))

        expand_dims = P.ExpandDims()
        b = expand_dims(b, 3)
        x = expand_dims(x, 3)
        y = expand_dims(y, 3)

        # GatherNd indices are (batch, row, column) triples.
        cast = P.Cast()
        indices = P.Concat(3)((b, y, x))
        indices = cast(indices, mindspore.int32)

        return cast(P.GatherNd()(img, indices), mindspore.float32)

    def affine_grid_generator(self, height, width, theta):
        """
        This function returns a sampling grid, which when
        used with the bilinear sampler on the input feature
        map, will create an output feature map that is an
        affine transformation [1] of the input feature map.

        Input
        -----
        - height: desired height of grid/output.

        - width: desired width of grid/output. height * width must equal
          the number of points in the grid built in the constructor.

        - theta: affine transform matrices of shape (num_batch, 2, 3).
          For each image in the batch, we have 6 theta parameters of
          the form (2x3) that define the affine transformation T.

        Returns
        -------
        - normalized grid of shape (num_batch, 2, height, width).
          The 2nd dimension has 2 components: (x, y) which are the
          sampling points of the original image for each point in the
          target image.

        Note
        ----
        [1]: the affine transformation allows cropping, translation,
             and isotropic scaling.
        """
        num_batch = P.Shape()(theta)[0]

        cast = P.Cast()
        theta = cast(theta, mindspore.float32)

        # Replicate the base grid for every sample in the batch.
        sampling_grid = P.Tile()(self.sampling_grid, (num_batch, 1, 1))
        sampling_grid = cast(sampling_grid, mindspore.float32)

        # (num_batch, 2, 3) x (num_batch, 3, H*W) -> (num_batch, 2, H*W)
        batch_grids = P.BatchMatMul()(theta, sampling_grid)

        # Unflatten the point axis to (num_batch, 2, height, width).
        batch_grids = P.Reshape()(batch_grids, (num_batch, 2, height, width))
        return batch_grids

    def bilinear_sampler(self, img, x, y):
        """
        Performs bilinear sampling of the input images according to the
        normalized coordinates provided by the sampling grid. Note that
        the sampling is done identically for each channel of the input.

        Input
        -----
        - img: batch of images in (B, H, W, C) layout.
        - x, y: normalized sampling coordinates, each of shape
          (B, H_out, W_out), as produced by affine_grid_generator.

        Returns
        -------
        - out: interpolated images of shape (B, H_out, W_out, C).
        """
        shape = P.Shape()
        H = shape(img)[1]
        W = shape(img)[2]
        cast = P.Cast()
        max_y = cast(H - 1, mindspore.float32)
        max_x = cast(W - 1, mindspore.float32)
        zero = self.zero

        # Rescale x from [-1, 1] to [0, W-2] and y to [0, H-2]
        # (max_x is W-1 and max_y is H-1; one more is subtracted here).
        x = 0.5 * ((x + 1.0) * (max_x - 1))
        y = 0.5 * ((y + 1.0) * (max_y - 1))

        # grab 4 nearest corner points for each (x_i, y_i)
        floor = P.Floor()
        x0 = floor(x)
        x1 = x0 + 1
        y0 = floor(y)
        y1 = y0 + 1

        # clip to range [0, W-1] / [0, H-1] to not violate img boundaries
        x0 = C.clip_by_value(x0, zero, max_x)
        x1 = C.clip_by_value(x1, zero, max_x)
        y0 = C.clip_by_value(y0, zero, max_y)
        y1 = C.clip_by_value(y1, zero, max_y)

        # get pixel value at corner coords
        Ia = self.get_pixel_value(img, x0, y0)
        Ib = self.get_pixel_value(img, x0, y1)
        Ic = self.get_pixel_value(img, x1, y0)
        Id = self.get_pixel_value(img, x1, y1)

        # Ensure float32 for the weight arithmetic (presumably already
        # float32 at this point, in which case this is a no-op).
        x0 = cast(x0, mindspore.float32)
        x1 = cast(x1, mindspore.float32)
        y0 = cast(y0, mindspore.float32)
        y1 = cast(y1, mindspore.float32)

        # Interpolation weights: each corner is weighted by the area of the
        # rectangle spanned by the sample point and the opposite corner.
        wa = (x1 - x) * (y1 - y)
        wb = (x1 - x) * (y - y0)
        wc = (x - x0) * (y1 - y)
        wd = (x - x0) * (y - y0)

        # add a trailing channel axis so the weights broadcast over C
        expand_dims = P.ExpandDims()
        wa = expand_dims(wa, 3)
        wb = expand_dims(wb, 3)
        wc = expand_dims(wc, 3)
        wd = expand_dims(wd, 3)

        # compute output
        out = P.AddN()([wa * Ia, wb * Ib, wc * Ic, wd * Id])

        return out

    def construct(self, input_fmap, theta, out_dims=None, **kwargs):
        """
        Spatial Transformer Network layer implementation as described in [1].

        This method applies two steps to the given input:

        - affine_grid_generator: generates a grid of (x,y) coordinates that
          correspond to a set of points where the input should be sampled
          to produce the transformed output.

        - bilinear_sampler: takes as input the original image and the grid
          and produces the output image using bilinear interpolation.

        Input
        -----
        - input_fmap: feature map of shape (B, C, H, W). It is transposed
          to (B, H, W, C) internally for sampling.

        - theta: affine transform tensor reshapeable to (B, 2, 3), e.g. of
          shape (B, 6).

        - out_dims: optional (out_H, out_W) size of the sampling grid.
          When omitted (or falsy) the input's H and W are used.

        - kwargs: accepted for interface compatibility; not used.

        Returns
        -------
        - out_fmap: transformed feature map in (B, C, H_out, W_out) layout.

        Notes
        -----
        [1]: 'Spatial Transformer Networks', Jaderberg et. al,
             (https://arxiv.org/abs/1506.02025)
        """
        # Move channels last, then grab the input dimensions.
        trans = P.Transpose()
        input_fmap = trans(input_fmap, (0, 2, 3, 1))
        input_size = P.Shape()(input_fmap)
        B = input_size[0]
        H = input_size[1]
        W = input_size[2]
        theta = P.Reshape()(theta, (B, 2, 3))

        # generate grids of same size or upsample/downsample if specified
        if out_dims:
            out_H = out_dims[0]
            out_W = out_dims[1]
            batch_grids = self.affine_grid_generator(out_H, out_W, theta)
        else:
            batch_grids = self.affine_grid_generator(H, W, theta)

        # Split (B, 2, H, W) into the x and y planes, each (B, 1, H, W).
        x_s, y_s = P.Split(1, 2)(batch_grids)
        # Squeeze only the split axis: an axis-less squeeze would also drop
        # the batch (or a spatial) axis whenever its size is 1.
        squeeze = P.Squeeze(1)
        x_s = squeeze(x_s)
        y_s = squeeze(y_s)

        out_fmap = self.bilinear_sampler(input_fmap, x_s, y_s)
        # Back to channels-first layout.
        out_fmap = trans(out_fmap, (0, 3, 1, 2))

        return out_fmap
|