
nn_ops_vm_impl.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for nn ops"""
import numpy as np

from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.common.tensor import Tensor
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters

from .vm_interface import vm

# pylint: disable=unused-argument
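
# Each getter below is registered for one primitive class and is called with the
# primitive instance as `self`. It returns a closure that mirrors the op on NumPy
# arrays via the `vm` helpers and wraps the results back into Tensor objects.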


@vm_impl_getters.register(P.ScalarSummary)
def vm_impl_scalar_summary(self):
    """Generate vm_impl function for ScalarSummary"""

    def vm_impl(string_in, scalar):
        """Implement by vm mode."""
        return scalar

    return vm_impl


@vm_impl_getters.register(P.ReLU)
def vm_impl_relu(self):
    """Generate vm_impl function for ReLU"""

    def vm_impl(x):
        x = x.asnumpy()
        output = Tensor(vm.relu(x))
        return output

    return vm_impl


@vm_impl_getters.register(P.Flatten)
def vm_impl_flatten(self):
    """Generate vm_impl function for Flatten"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.flatten_batch(x))

    return vm_impl


@vm_impl_getters.register(P.Softmax)
def vm_impl_softmax(self):
    """Generate vm_impl function for Softmax"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.softmax(x))

    return vm_impl


@vm_impl_getters.register(P.LogSoftmax)
def vm_impl_log_softmax(self):
    """Generate vm_impl function for LogSoftmax"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.logsoftmax(x))

    return vm_impl


@vm_impl_getters.register(P.Tanh)
def vm_impl_tanh(self):
    """Generate vm_impl function for Tanh"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.tanh(x))

    return vm_impl


@vm_impl_getters.register(P.FusedBatchNorm)
def vm_impl_fused_batch_norm(self):
    """Generate vm_impl function for FusedBatchNorm"""

    def vm_impl(x, scale, b, mean, variance):
        # pylint: disable=unused-argument
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(x, scale, b, mean,
                                                                      variance,
                                                                      eps=self.epsilon,
                                                                      momentum=self.momentum)
        return (Tensor(out), Tensor(x_mean), Tensor(x_var),
                Tensor(running_mean), Tensor(running_var))

    return vm_impl


@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm"""

    def vm_impl(x, scale, b, mean, variance):
        # pylint: disable=unused-argument
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(x, scale, b, mean,
                                                                      variance,
                                                                      eps=self.epsilon)
        return (Tensor(out), Tensor(x_mean), Tensor(x_var),
                Tensor(running_mean), Tensor(running_var))

    return vm_impl


@vm_impl_getters.register(P.Conv2D)
def vm_impl_conv2d(self):
    """Generate vm_impl function for Conv2D"""

    def vm_impl(x, w):
        x = x.asnumpy()
        weight = w.asnumpy()
        bias = None
        out = vm.conv2d(x, weight, bias, self.stride, self.pad, self.dilation)
        return Tensor(out)

    return vm_impl
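
# The closure above reads its hyper-parameters from the primitive instance
# (self.stride, self.pad, self.dilation). The lines below are only a hedged sketch
# of how such a primitive is typically constructed and resolved to its NumPy-backed
# closure; the exact Conv2D constructor arguments depend on the MindSpore version.
#     conv2d = P.Conv2D(out_channel=32, kernel_size=3, pad_mode="pad", pad=1, stride=1)  # assumed signature
#     conv2d_vm = vm_impl_conv2d(conv2d)   # returns the inner vm_impl closure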


@vm_impl_getters.register(G.MaxPoolGradWithArgmax)
def vm_impl_max_pool_grad_with_argmax(self):
    """Generate vm_impl function for MaxPoolGradWithArgmax"""

    def vm_impl(x, dout, argmax):
        x = x.asnumpy()
        dout = dout.asnumpy()
        arg_max = argmax.asnumpy()
        dx = vm.max_pool_grad_with_argmax(x, dout, arg_max,
                                          self.ksize[1], self.ksize[2], self.strides[1])
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(P.MaxPoolWithArgmax)
def vm_impl_max_pool_with_argmax(self):
    """Generate vm_impl function for MaxPoolWithArgmax"""

    def vm_impl(x):
        x = x.asnumpy()
        out, out_argmax = vm.max_pool_with_argmax(x, self.ksize[1], self.ksize[2], self.strides[1])
        return Tensor(out), Tensor(out_argmax)

    return vm_impl


@vm_impl_getters.register(P.MaxPool)
def vm_impl_max_pool(self):
    """Generate vm_impl function for MaxPool"""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.max_pooling(x, self.ksize[-2], self.ksize[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.MaxPoolGrad)
def vm_impl_max_pool_grad(self):
    """Generate vm_impl function for MaxPoolGrad"""

    def vm_impl(x, out, dout):
        x = x.asnumpy()
        dout = dout.asnumpy()
        out = vm.max_pool_grad(x, dout, self.ksize[-2], self.ksize[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.AvgPool)
def vm_impl_avg_pool(self):
    """Generate vm_impl function for AvgPool"""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.avg_pooling(x, self.ksize[-2], self.ksize[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.AvgPoolGrad)
def vm_impl_avg_pool_grad(self):
    """Generate vm_impl function for AvgPoolGrad"""

    def vm_impl(dout, origin_shape):
        dout = dout.asnumpy()
        out = vm.avg_pool_grad(dout, origin_shape, self.ksize[-2], self.ksize[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.FusedBatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):
    """Generate vm_impl function for FusedBatchNormGrad"""

    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))

    return vm_impl


@vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_batch_norm_grad(self):
    """Generate vm_impl function for BatchNormGrad"""

    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))

    return vm_impl


@vm_impl_getters.register(G.ReluGrad)
def vm_impl_relu_grad(self):
    """Generate vm_impl function for ReluGrad"""

    def vm_impl(y_backprop, x):
        x = x.asnumpy()
        y_backprop = y_backprop.asnumpy()
        y_backprop = vm.relu_grad(x.copy()) * y_backprop
        return Tensor(y_backprop)

    return vm_impl


@vm_impl_getters.register(P.Conv2DBackpropInput)
def vm_impl_conv2d_backprop_input(self):
    """Generate vm_impl function for Conv2DBackpropInput"""

    def vm_impl(dout, w, x_size):
        dout = dout.asnumpy()
        w = w.asnumpy()
        dx = vm.conv2d_backprop_input(dout, x_size, w, self.stride, self.pad)
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(G.Conv2DBackpropFilter)
def vm_impl_conv2d_backprop_filter(self):
    """Generate vm_impl function for Conv2DBackpropFilter"""

    def vm_impl(dout, x, w_size):
        x = x.asnumpy()
        dout = dout.asnumpy()
        dw = vm.conv2d_backprop_filter(dout, x, w_size, self.stride, self.pad)
        return Tensor(dw)

    return vm_impl


@vm_impl_getters.register(G.FlattenGrad)
def vm_impl_flatten_grad(self):
    """Generate vm_impl function for FlattenGrad"""

    def vm_impl(dout, x):
        dout = dout.asnumpy()
        dout = vm.flatten_grad(dout, x)
        return Tensor(dout)

    return vm_impl


@vm_impl_getters.register(P.BiasAdd)
def vm_impl_bias_add(self):
    """Generate vm_impl function for BiasAdd"""

    def vm_impl(wx, bias):
        wx = wx.asnumpy()
        bias = bias.asnumpy()
        out = wx + bias
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.BiasAddGrad)
def vm_impl_bias_add_grad(self):
    """Generate vm_impl function for BiasAddGrad"""

    def vm_impl(dout):
        dout = dout.asnumpy()
        shape = np.shape(dout)
        return Tensor(np.add.reduce(dout, axis=tuple(range(len(shape) - 1))))

    return vm_impl


@vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits)
def vm_impl_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, labels)
        return (Tensor(np.array(loss)), Tensor(dx))

    return vm_impl


@vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def vm_impl_sparse_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        n_class = labels.max() + 1
        n_sample = labels.shape[0]
        # expand the sparse class indices into a dense one-hot matrix (n_sample x n_class)
        one_hot_label = np.zeros((n_sample, n_class))
        one_hot_label[np.arange(n_sample), labels] = 1
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label)
        if self.is_grad:
            return (Tensor(dx),)
        return (Tensor(np.array(loss)),)

    return vm_impl
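
# Worked example of the one-hot expansion above (illustrative values only):
#     labels = np.array([0, 2, 1])          # 3 samples, classes in {0, 1, 2}
#     one_hot = np.zeros((3, 3))
#     one_hot[np.arange(3), labels] = 1
#     # one_hot -> [[1, 0, 0],
#     #             [0, 0, 1],
#     #             [0, 1, 0]]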


@vm_impl_getters.register(P.ApplyMomentum)
def vm_impl_momentum(self):
    """Generate vm_impl function for Momentum"""

    def vm_impl(variable,
                accumulation,
                learning_rate,
                gradient,
                momentum,
                use_nesterov=False):
        gradient = gradient.asnumpy()
        accumulation = accumulation.asnumpy()
        variable = variable.asnumpy()
        shape = accumulation.shape
        learning_rate = np.full(shape, learning_rate)
        momentum = np.full(shape, momentum)
        accumulation = accumulation * momentum + gradient
        if use_nesterov:
            variable -= gradient * learning_rate + accumulation * momentum * learning_rate
        else:
            variable -= accumulation * learning_rate
        return Tensor(variable)

    return vm_impl
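
# The update computed above, written out as equations (derived from the code):
#     accum_new = momentum * accum + grad
#     standard:  var -= lr * accum_new
#     nesterov:  var -= lr * (grad + momentum * accum_new)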


@vm_impl_getters.register(P.ResizeBilinear)
def vm_impl_resize_bilinear(self):
    """Generate vm_impl function for ResizeBilinear"""

    def vm_impl(x):
        out = vm.ResizeBilinear(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.ResizeBilinearGrad)
def vm_impl_resize_bilinear_grad(self):
    """Generate vm_impl function for ResizeBilinearGrad"""

    def vm_impl(dout, original_image):
        out = vm.ResizeBilinearGrad(dout, original_image)
        return Tensor(out)

    return vm_impl
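

# Minimal usage sketch (an assumption, not part of this file): the closure
# registered for a primitive can typically be resolved through the registry module
# and then called on Tensors, roughly as follows.
#     from mindspore.ops.vm_impl_registry import get_vm_impl_fn  # assumed helper
#     relu_fn = get_vm_impl_fn(P.ReLU())          # resolves to vm_impl_relu(P.ReLU())
#     y = relu_fn(Tensor(np.array([[-1.0, 2.0]], np.float32)))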