
nn_ops_vm_impl.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for nn ops"""
import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters

from .vm_interface import vm

# pylint: disable=unused-argument
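
# Each getter below is registered for one primitive class via
# @vm_impl_getters.register(...). It receives the primitive instance (`self`,
# which carries op attributes such as stride, kernel_size or epsilon) and
# returns a closure that evaluates the op on numpy arrays and wraps the
# result back into a Tensor.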


@vm_impl_getters.register(P.ScalarSummary)
def vm_impl_scalar_summary(self):
    """Generate vm_impl function for ScalarSummary"""
    def vm_impl(string_in, scalar):
        """Implement by vm mode."""
        return scalar
    return vm_impl


@vm_impl_getters.register(P.ReLU)
def vm_impl_relu(self):
    """Generate vm_impl function for ReLU"""
    def vm_impl(x):
        x = x.asnumpy()
        output = Tensor(vm.relu(x))
        return output
    return vm_impl


@vm_impl_getters.register(P.Flatten)
def vm_impl_flatten(self):
    """Generate vm_impl function for Flatten"""
    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.flatten_batch(x))
    return vm_impl


@vm_impl_getters.register(P.Softmax)
def vm_impl_softmax(self):
    """Generate vm_impl function for Softmax"""
    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.softmax(x))
    return vm_impl


@vm_impl_getters.register(P.LogSoftmax)
def vm_impl_log_softmax(self):
    """Generate vm_impl function for LogSoftmax"""
    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.logsoftmax(x))
    return vm_impl


@vm_impl_getters.register(P.Tanh)
def vm_impl_tanh(self):
    """Generate vm_impl function for Tanh"""
    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.tanh(x))
    return vm_impl


@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm"""
    def vm_impl(x, scale, b, mean, variance):
        # pylint: disable=unused-argument
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(
            x, scale, b, mean, variance, eps=self.epsilon)
        return (Tensor(out), Tensor(x_mean), Tensor(x_var),
                Tensor(running_mean), Tensor(running_var))
    return vm_impl


@vm_impl_getters.register(P.Conv2D)
def vm_impl_conv2d(self):
    """Generate vm_impl function for Conv2D"""
    def vm_impl(x, w):
        x = x.asnumpy()
        weight = w.asnumpy()
        bias = None
        out = vm.conv2d(x, weight, bias, self.stride, self.pad, self.dilation)
        return Tensor(out)
    return vm_impl


@vm_impl_getters.register(G.MaxPoolGradWithArgmax)
def vm_impl_max_pool_grad_with_argmax(self):
    """Generate vm_impl function for MaxPoolGradWithArgmax"""
    def vm_impl(x, dout, argmax):
        x = x.asnumpy()
        dout = dout.asnumpy()
        arg_max = argmax.asnumpy()
        dx = vm.max_pool_grad_with_argmax(x, dout, arg_max,
                                          self.kernel_size[1], self.kernel_size[2],
                                          self.strides[1])
        return Tensor(dx)
    return vm_impl


@vm_impl_getters.register(P.MaxPoolWithArgmax)
def vm_impl_max_pool_with_argmax(self):
    """Generate vm_impl function for MaxPoolWithArgmax"""
    def vm_impl(x):
        x = x.asnumpy()
        out, out_argmax = vm.max_pool_with_argmax(x, self.kernel_size[1],
                                                  self.kernel_size[2], self.strides[1])
        return Tensor(out), Tensor(out_argmax)
    return vm_impl


@vm_impl_getters.register(P.MaxPool)
def vm_impl_max_pool(self):
    """Generate vm_impl function for MaxPool"""
    def vm_impl(x):
        x = x.asnumpy()
        out = vm.max_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)
    return vm_impl


@vm_impl_getters.register(G.MaxPoolGrad)
def vm_impl_max_pool_grad(self):
    """Generate vm_impl function for MaxPoolGrad"""
    def vm_impl(x, out, dout):
        x = x.asnumpy()
        dout = dout.asnumpy()
        dx = vm.max_pool_grad(x, dout, self.kernel_size[-2], self.kernel_size[-1],
                              self.strides[-2])
        return Tensor(dx)
    return vm_impl


@vm_impl_getters.register(P.AvgPool)
def vm_impl_avg_pool(self):
    """Generate vm_impl function for AvgPool"""
    def vm_impl(x):
        x = x.asnumpy()
        out = vm.avg_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)
    return vm_impl


@vm_impl_getters.register(G.AvgPoolGrad)
def vm_impl_avg_pool_grad(self):
    """Generate vm_impl function for AvgPoolGrad"""
    def vm_impl(dout, origin_shape):
        dout = dout.asnumpy()
        out = vm.avg_pool_grad(dout, origin_shape, self.kernel_size[-2],
                               self.kernel_size[-1], self.strides[-2])
        return Tensor(out)
    return vm_impl


# pylint: disable=function-redefined
@vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):
    """Generate vm_impl function for BatchNormGrad"""
    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))
    return vm_impl


@vm_impl_getters.register(G.ReluGrad)
def vm_impl_relu_grad(self):
    """Generate vm_impl function for ReluGrad"""
    def vm_impl(y_backprop, x):
        x = x.asnumpy()
        y_backprop = y_backprop.asnumpy()
        # Mask the incoming gradient with the ReLU activation pattern of x.
        y_backprop = vm.relu_grad(x.copy()) * y_backprop
        return Tensor(y_backprop)
    return vm_impl


@vm_impl_getters.register(P.Conv2DBackpropInput)
def vm_impl_conv2d_backprop_input(self):
    """Generate vm_impl function for Conv2DBackpropInput"""
    def vm_impl(dout, w, x_size):
        dout = dout.asnumpy()
        w = w.asnumpy()
        dx = vm.conv2d_backprop_input(dout, x_size, w, self.stride, self.pad)
        return Tensor(dx)
    return vm_impl


@vm_impl_getters.register(G.Conv2DBackpropFilter)
def vm_impl_conv2d_backprop_filter(self):
    """Generate vm_impl function for Conv2DBackpropFilter"""
    def vm_impl(dout, x, w_size):
        x = x.asnumpy()
        dout = dout.asnumpy()
        dw = vm.conv2d_backprop_filter(dout, x, w_size, self.stride, self.pad)
        return Tensor(dw)
    return vm_impl


@vm_impl_getters.register(G.FlattenGrad)
def vm_impl_flatten_grad(self):
    """Generate vm_impl function for FlattenGrad"""
    def vm_impl(dout, x):
        dout = dout.asnumpy()
        dout = vm.flatten_grad(dout, x)
        return Tensor(dout)
    return vm_impl


@vm_impl_getters.register(P.BiasAdd)
def vm_impl_bias_add(self):
    """Generate vm_impl function for BiasAdd"""
    def vm_impl(wx, bias):
        wx = wx.asnumpy()
        bias = bias.asnumpy()
        out = wx + bias
        return Tensor(out)
    return vm_impl


@vm_impl_getters.register(G.BiasAddGrad)
def vm_impl_bias_add_grad(self):
    """Generate vm_impl function for BiasAddGrad"""
    def vm_impl(dout):
        dout = dout.asnumpy()
        shape = np.shape(dout)
        # Sum the gradient over every axis except the trailing channel axis.
        return Tensor(np.add.reduce(dout, axis=tuple(range(len(shape) - 1))))
    return vm_impl


@vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits)
def vm_impl_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SoftmaxCrossEntropyWithLogits"""
    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, labels)
        return (Tensor(np.array(loss)), Tensor(dx))
    return vm_impl


@vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def vm_impl_sparse_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits"""
    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        n_class = labels.max() + 1
        n_sample = labels.shape[0]
        # Expand the sparse labels into a one-hot matrix of shape
        # (n_sample, n_class): exactly one 1 per row, at each sample's label.
        one_hot_label = np.zeros((n_sample, n_class))
        one_hot_label[np.arange(n_sample), labels] = 1
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label)
        if self.is_grad:
            return (Tensor(dx),)
        return (Tensor(np.array(loss)),)
    return vm_impl
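
# The momentum update implemented by vm_impl_momentum below:
#   accumulation = accumulation * momentum + gradient
#   variable -= learning_rate * accumulation                            (standard)
#   variable -= learning_rate * (gradient + momentum * accumulation)    (nesterov)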


@vm_impl_getters.register(P.ApplyMomentum)
def vm_impl_momentum(self):
    """Generate vm_impl function for ApplyMomentum"""
    def vm_impl(variable, accumulation, learning_rate, gradient, momentum,
                use_nesterov=False):
        gradient = gradient.asnumpy()
        accumulation = accumulation.asnumpy()
        variable = variable.asnumpy()
        shape = accumulation.shape
        # Broadcast the scalar hyperparameters to the parameter shape.
        learning_rate = np.full(shape, learning_rate.asnumpy())
        momentum = np.full(shape, momentum.asnumpy())
        accumulation = accumulation * momentum + gradient
        if use_nesterov:
            variable -= gradient * learning_rate + accumulation * momentum * learning_rate
        else:
            variable -= accumulation * learning_rate
        return Tensor(variable)
    return vm_impl


@vm_impl_getters.register(P.ResizeBilinear)
def vm_impl_resize_bilinear(self):
    """Generate vm_impl function for ResizeBilinear"""
    def vm_impl(x):
        out = vm.ResizeBilinear(x)
        return Tensor(out)
    return vm_impl


@vm_impl_getters.register(G.ResizeBilinearGrad)
def vm_impl_resize_bilinear_grad(self):
    """Generate vm_impl function for ResizeBilinearGrad"""
    def vm_impl(dout, original_image):
        out = vm.ResizeBilinearGrad(dout, original_image)
        return Tensor(out)
    return vm_impl
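

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the original module): each
# getter above takes the primitive instance and returns a numpy-backed
# closure. A minimal demonstration for ReLU follows; it assumes a working
# MindSpore install plus the local `vm_interface` backend, and is guarded so
# that importing this module stays side-effect free.
if __name__ == "__main__":
    relu_fn = vm_impl_relu(P.ReLU())  # build the vm_impl closure for ReLU
    demo_in = Tensor(np.array([-1.0, 0.0, 2.5], np.float32))
    print(relu_fn(demo_in))  # negatives clamped to zero: [0. 0. 2.5]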