You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

stn.py 9.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """STN module"""
  16. import numpy as np
  17. import mindspore
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import composite as C
  21. import mindspore.nn as nn
  22. class STN(nn.Cell):
  23. '''STN'''
  24. def __init__(self, H, W):
  25. super(STN, self).__init__()
  26. batch_size = 1
  27. x = np.linspace(-1.0, 1.0, H)
  28. y = np.linspace(-1.0, 1.0, W)
  29. x_t, y_t = np.meshgrid(x, y)
  30. x_t = Tensor(x_t, mindspore.float32)
  31. y_t = Tensor(y_t, mindspore.float32)
  32. expand_dims = P.ExpandDims()
  33. x_t = expand_dims(x_t, 0)
  34. y_t = expand_dims(y_t, 0)
  35. flatten = P.Flatten()
  36. x_t_flat = flatten(x_t)
  37. y_t_flat = flatten(y_t)
  38. oneslike = P.OnesLike()
  39. ones = oneslike(x_t_flat)
  40. concat = P.Concat()
  41. sampling_grid = concat((x_t_flat, y_t_flat, ones))
  42. self.sampling_grid = expand_dims(sampling_grid, 0)
  43. batch_size = 128
  44. batch_idx = np.arange(batch_size)
  45. batch_idx = batch_idx.reshape((batch_size, 1, 1))
  46. self.batch_idx = Tensor(batch_idx, mindspore.float32)
  47. self.zero = Tensor(np.zeros([]), mindspore.float32)
  48. def get_pixel_value(self, img, x, y):
  49. """
  50. Utility function to get pixel value for coordinate
  51. vectors x and y from a 4D tensor image.
  52. Input
  53. -----
  54. - img: tensor of shape (B, H, W, C)
  55. - x: flattened tensor of shape (B*H*W,)
  56. - y: flattened tensor of shape (B*H*W,)
  57. Returns
  58. -------
  59. - output: tensor of shape (B, H, W, C)
  60. """
  61. shape = P.Shape()
  62. img_shape = shape(x)
  63. batch_size = img_shape[0]
  64. height = img_shape[1]
  65. width = img_shape[2]
  66. img[:, 0, :, :] = self.zero
  67. img[:, height-1, :, :] = self.zero
  68. img[:, :, 0, :] = self.zero
  69. img[:, :, width-1, :] = self.zero
  70. tile = P.Tile()
  71. batch_idx = P.Slice()(self.batch_idx, (0, 0, 0), (batch_size, 1, 1))
  72. b = tile(batch_idx, (1, height, width))
  73. expand_dims = P.ExpandDims()
  74. b = expand_dims(b, 3)
  75. x = expand_dims(x, 3)
  76. y = expand_dims(y, 3)
  77. concat = P.Concat(3)
  78. indices = concat((b, y, x))
  79. cast = P.Cast()
  80. indices = cast(indices, mindspore.int32)
  81. gather_nd = P.GatherNd()
  82. return cast(gather_nd(img, indices), mindspore.float32)
  83. def affine_grid_generator(self, height, width, theta):
  84. """
  85. This function returns a sampling grid, which when
  86. used with the bilinear sampler on the input feature
  87. map, will create an output feature map that is an
  88. affine transformation [1] of the input feature map.
  89. zero = Tensor(np.zeros([]), mindspore.float32)
  90. Input
  91. -----
  92. - height: desired height of grid/output. Used
  93. to downsample or upsample.
  94. - width: desired width of grid/output. Used
  95. to downsample or upsample.
  96. - theta: affine transform matrices of shape (num_batch, 2, 3).
  97. For each image in the batch, we have 6 theta parameters of
  98. the form (2x3) that define the affine transformation T.
  99. Returns
  100. -------
  101. - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
  102. The 2nd dimension has 2 components: (x, y) which are the
  103. sampling points of the original image for each point in the
  104. target image.
  105. Note
  106. ----
  107. [1]: the affine transformation allows cropping, translation,
  108. and isotropic scaling.
  109. """
  110. shape = P.Shape()
  111. num_batch = shape(theta)[0]
  112. cast = P.Cast()
  113. theta = cast(theta, mindspore.float32)
  114. # transform the sampling grid - batch multiply
  115. matmul = P.BatchMatMul()
  116. tile = P.Tile()
  117. sampling_grid = tile(self.sampling_grid, (num_batch, 1, 1))
  118. cast = P.Cast()
  119. sampling_grid = cast(sampling_grid, mindspore.float32)
  120. batch_grids = matmul(theta, sampling_grid)
  121. # batch grid has shape (num_batch, 2, H*W)
  122. # reshape to (num_batch, H, W, 2)
  123. reshape = P.Reshape()
  124. batch_grids = reshape(batch_grids, (num_batch, 2, height, width))
  125. return batch_grids
  126. def bilinear_sampler(self, img, x, y):
  127. """
  128. Performs bilinear sampling of the input images according to the
  129. normalized coordinates provided by the sampling grid. Note that
  130. the sampling is done identically for each channel of the input.
  131. To test if the function works properly, output image should be
  132. identical to input image when theta is initialized to identity
  133. transform.
  134. Input
  135. -----
  136. - img: batch of images in (B, H, W, C) layout.
  137. - grid: x, y which is the output of affine_grid_generator.
  138. Returns
  139. -------
  140. - out: interpolated images according to grids. Same size as grid.
  141. """
  142. shape = P.Shape()
  143. H = shape(img)[1]
  144. W = shape(img)[2]
  145. cast = P.Cast()
  146. max_y = cast(H - 1, mindspore.float32)
  147. max_x = cast(W - 1, mindspore.float32)
  148. zero = self.zero
  149. # rescale x and y to [0, W-1/H-1]
  150. x = 0.5 * ((x + 1.0) * (max_x-1))
  151. y = 0.5 * ((y + 1.0) * (max_y-1))
  152. # grab 4 nearest corner points for each (x_i, y_i)
  153. floor = P.Floor()
  154. x0 = floor(x)
  155. x1 = x0 + 1
  156. y0 = floor(y)
  157. y1 = y0 + 1
  158. # clip to range [0, H-1/W-1] to not violate img boundaries
  159. x0 = C.clip_by_value(x0, zero, max_x)
  160. x1 = C.clip_by_value(x1, zero, max_x)
  161. y0 = C.clip_by_value(y0, zero, max_y)
  162. y1 = C.clip_by_value(y1, zero, max_y)
  163. # get pixel value at corner coords
  164. Ia = self.get_pixel_value(img, x0, y0)
  165. Ib = self.get_pixel_value(img, x0, y1)
  166. Ic = self.get_pixel_value(img, x1, y0)
  167. Id = self.get_pixel_value(img, x1, y1)
  168. # recast as float for delta calculation
  169. x0 = cast(x0, mindspore.float32)
  170. x1 = cast(x1, mindspore.float32)
  171. y0 = cast(y0, mindspore.float32)
  172. y1 = cast(y1, mindspore.float32)
  173. # calculate deltas
  174. wa = (x1-x) * (y1-y)
  175. wb = (x1-x) * (y-y0)
  176. wc = (x-x0) * (y1-y)
  177. wd = (x-x0) * (y-y0)
  178. # add dimension for addition
  179. expand_dims = P.ExpandDims()
  180. wa = expand_dims(wa, 3)
  181. wb = expand_dims(wb, 3)
  182. wc = expand_dims(wc, 3)
  183. wd = expand_dims(wd, 3)
  184. # compute output
  185. add_n = P.AddN()
  186. out = add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
  187. return out
  188. def construct(self, input_fmap, theta, out_dims=None, **kwargs):
  189. """
  190. Spatial Transformer Network layer implementation as described in [1].
  191. The layer is composed of 3 elements:
  192. - localization_net: takes the original image as input and outputs
  193. the parameters of the affine transformation that should be applied
  194. to the input image.
  195. - affine_grid_generator: generates a grid of (x,y) coordinates that
  196. correspond to a set of points where the input should be sampled
  197. to produce the transformed output.
  198. - bilinear_sampler: takes as input the original image and the grid
  199. and produces the output image using bilinear interpolation.
  200. Input
  201. -----
  202. - input_fmap: output of the previous layer. Can be input if spatial
  203. transformer layer is at the beginning of architecture. Should be
  204. a tensor of shape (B, H, W, C).
  205. - theta: affine transform tensor of shape (B, 6). Permits cropping,
  206. translation and isotropic scaling. Initialize to identity matrix.
  207. It is the output of the localization network.
  208. Returns
  209. -------
  210. - out_fmap: transformed input feature map. Tensor of size (B, C, H, W)-->(B, H, W, C).
  211. Notes
  212. -----
  213. [1]: 'Spatial Transformer Networks', Jaderberg et. al,
  214. (https://arxiv.org/abs/1506.02025)
  215. """
  216. # grab input dimensions
  217. trans = P.Transpose()
  218. input_fmap = trans(input_fmap, (0, 2, 3, 1))
  219. shape = P.Shape()
  220. input_size = shape(input_fmap)
  221. B = input_size[0]
  222. H = input_size[1]
  223. W = input_size[2]
  224. reshape = P.Reshape()
  225. theta = reshape(theta, (B, 2, 3))
  226. # generate grids of same size or upsample/downsample if specified
  227. if out_dims:
  228. out_H = out_dims[0]
  229. out_W = out_dims[1]
  230. batch_grids = self.affine_grid_generator(out_H, out_W, theta)
  231. else:
  232. batch_grids = self.affine_grid_generator(H, W, theta)
  233. x_s, y_s = P.Split(1, 2)(batch_grids)
  234. squeeze = P.Squeeze()
  235. x_s = squeeze(x_s)
  236. y_s = squeeze(y_s)
  237. out_fmap = self.bilinear_sampler(input_fmap, x_s, y_s)
  238. out_fmap = trans(out_fmap, (0, 3, 1, 2))
  239. return out_fmap