You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_nn_ops.py 31 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of neural network related operations."""
  16. import numpy as np
  17. from mindspore.ops import _selected_grad_ops as SG
  18. from mindspore.ops.primitive import constexpr
  19. from mindspore.common.tensor import Tensor
  20. from .grad_base import bprop_getters
  21. from .. import functional as F
  22. from .. import operations as P
  23. from ...common import dtype as mstype
  24. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  25. from ..operations import _grad_ops as G
  26. from ..operations import _inner_ops as inner
  27. from ... import context
  28. @bprop_getters.register(P.BiasAdd)
  29. def get_bprop_bias_add(self):
  30. """Grad definition for `BiasAdd` operation."""
  31. bias_grad = SG.BiasAddGrad()
  32. def bprop(x, w, out, dout):
  33. return dout, bias_grad(dout)
  34. return bprop
  35. @bprop_getters.register(P.Conv2D)
  36. def get_bprop_conv2d(self):
  37. """Grad definition for `Conv2D` operation."""
  38. input_grad = P.Conv2DBackpropInput(
  39. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  40. dilation=self.dilation, stride=self.stride, group=self.group
  41. )
  42. filter_grad = G.Conv2DBackpropFilter(
  43. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  44. dilation=self.dilation, stride=self.stride, group=self.group
  45. )
  46. get_shape = P.Shape()
  47. def bprop(x, w, out, dout):
  48. dx = input_grad(dout, w, get_shape(x))
  49. dw = filter_grad(dout, x, get_shape(w))
  50. return dx, dw
  51. return bprop
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation."""
    get_shape = P.Shape()
    reshape = P.Reshape()
    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
                                                      strides=self.strides,
                                                      rates=self.rates,
                                                      padding=self.padding)
    concat = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()
    _, ksizes_row, ksizes_col, _ = self.ksizes

    def bprop(x, out, dout):
        # Strategy: number every input pixel (1-based; 0 is reserved for
        # padded positions), run the forward op on those index images to
        # learn which input pixel feeds each output slot, build a 0/1
        # routing matrix from that mapping, then multiply it with the
        # flattened output gradient to scatter-add dout back onto the input.
        # This branch unpacks (batch, depth, row, col), i.e. NCHW layout.
        x_shape = get_shape(x)
        x_batch, x_depth, x_row, x_col = x_shape
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
        out_shape = get_shape(out)
        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
        # Pair (input index, output index) for every output slot.
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        # Drop row 0, which collected the contributions of padded positions.
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))
        # Route each output-gradient element back to its source pixel.
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    def bprop_ge(x, out, dout):
        # Same routing-matrix construction as `bprop`, but this branch
        # unpacks (batch, row, col, depth), i.e. NHWC layout, for the
        # GE backend (no casts/transposes of the index image needed).
        x_shape = get_shape(x)
        x_batch, x_row, x_col, x_depth = x_shape
        x_indices_num = x_row * x_col + 1
        x_idx = F.tuple_to_array(range(1, x_indices_num))
        x_idx = reshape(x_idx, (1, x_row, x_col, 1))
        x_idx_patch = extract_image_patches(x_idx)
        out_shape = get_shape(out)
        _, out_row, out_col, _ = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
        grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 0, 1, 3))
        return (dx,)

    if context.get_context("enable_ge"):
        return bprop_ge
    return bprop
  124. @bprop_getters.register(P.DepthwiseConv2dNative)
  125. def get_bprop_depthwise_conv2d_native(self):
  126. """Grad definition for `DepthwiseConv2dNative` operation."""
  127. input_grad = G.DepthwiseConv2dNativeBackpropInput(
  128. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
  129. self.dilation, self.group
  130. )
  131. filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
  132. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
  133. self.dilation, self.group
  134. )
  135. get_shape = P.Shape()
  136. def bprop(x, w, out, dout):
  137. dx = input_grad(get_shape(x), w, dout)
  138. dw = filter_grad(x, get_shape(w), dout)
  139. return dx, dw
  140. return bprop
  141. @bprop_getters.register(P.MaxPoolWithArgmax)
  142. def get_bprop_max_pool_with_argmax(self):
  143. """Grad definition for `MaxPoolWithArgmax` operation."""
  144. maxpool_grad = G.MaxPoolGradWithArgmax(
  145. ksize=self.ksize,
  146. strides=self.strides,
  147. padding=self.padding)
  148. def bprop(x, out, dout):
  149. dx = maxpool_grad(x, dout[0], out[1])
  150. return (dx,)
  151. return bprop
  152. @bprop_getters.register(G.MaxPoolGrad)
  153. def get_bprop_max_pool_grad_grad(self):
  154. """Grad definition for `MaxPoolGrad` operation."""
  155. maxpool_grad_grad = G.MaxPoolGradGrad(
  156. ksize=self.ksize,
  157. strides=self.strides,
  158. padding=self.padding)
  159. def bprop(x1, x2, grad, out, dout):
  160. dx1 = zeros_like(x1)
  161. dx2 = zeros_like(x2)
  162. dgrad = maxpool_grad_grad(x1, x2, dout)
  163. return (dx1, dx2, dgrad)
  164. return bprop
  165. @bprop_getters.register(G.MaxPoolGradGrad)
  166. def get_bprop_max_pool_grad_grad_grad(self):
  167. """Grad definition for `MaxPoolGradGrad` operation."""
  168. maxpool_grad = G.MaxPoolGrad(
  169. ksize=self.ksize,
  170. strides=self.strides,
  171. padding=self.padding)
  172. def bprop(x1, x2, grad, out, dout):
  173. dx1 = zeros_like(x1)
  174. dx2 = zeros_like(x2)
  175. dgrad = maxpool_grad(x1, x2, dout)
  176. return (dx1, dx2, dgrad)
  177. return bprop
  178. @bprop_getters.register(P.MaxPool)
  179. def get_bprop_max_pool_grad(self):
  180. """Grad definition for `MaxPool` operation."""
  181. maxpool_grad = G.MaxPoolGrad(
  182. ksize=self.ksize,
  183. strides=self.strides,
  184. padding=self.padding)
  185. def bprop(x, out, dout):
  186. dx = maxpool_grad(x, out, dout)
  187. return (dx,)
  188. return bprop
  189. def _windowed_output_size(input_size, ksize, stride, padding):
  190. """
  191. helper func for AvgPoolGrad
  192. """
  193. tmp_output = 0
  194. tmp_pad_need = 0
  195. tmp_pad_before = 0
  196. tmp_pad_after = 0
  197. if padding == 'VALID':
  198. tmp_output = (input_size - ksize + stride) // stride
  199. tmp_pad_before = 0
  200. tmp_pad_after = 0
  201. elif padding == 'SAME':
  202. tmp_output = (input_size + stride - 1) // stride
  203. tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
  204. tmp_pad_before = tmp_pad_need // 2
  205. tmp_pad_after = tmp_pad_need - tmp_pad_before
  206. return tmp_output, tmp_pad_before, tmp_pad_after
@constexpr
def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype):
    """
    helper func for AvgPoolGrad.

    `assist_input_matrix` is a 2d matrix with input_shape after padding,
    the value of element which is padded is 0, else are 1.
    For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
    w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
    number of input that associate with output element.

    Returns a Tensor shaped (n, c, h_out, w_out) holding 1/window_count per
    output position, or None for global pooling with VALID padding.
    """
    n_input, c_input, h_input, w_input = x_shape
    h_ksize, w_ksize = ksize[2], ksize[3]
    # Global pooling with VALID padding needs no assist matrix.
    if h_ksize == h_input and w_ksize == w_input and padding == "VALID":
        return None
    h_stride, w_stride = stride[2], stride[3]
    n_output = n_input
    c_output = c_input
    h_output, w_output = 0, 0
    pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
    h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
                                                          h_stride, padding)
    w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
                                                          w_stride, padding)
    output_size = n_output * c_output * h_output * w_output
    output_shape = (n_output, c_output, h_output, w_output)
    output = np.array([0.0] * output_size)
    output = np.reshape(output, output_shape)
    # Plane of ones for the original image with a border of zeros for padding.
    in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
    assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
    if pad_top > 0:
        assist_input_matrix[:pad_top, :] = 0
    if pad_bottom > 0:
        assist_input_matrix[-pad_bottom:, :] = 0
    if pad_left > 0:
        assist_input_matrix[:, :pad_left] = 0
    if pad_right > 0:
        assist_input_matrix[:, -pad_right:] = 0
    # Each output cell's scale is 1 / (count of real pixels in its window).
    for h in range(h_output):
        for w in range(w_output):
            curr_input = assist_input_matrix[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize]
            curr_sum = np.sum(curr_input)
            if curr_sum > 0:
                output[:, :, h, w] = 1. / curr_sum
    return Tensor(output, x_dtype)
  251. @constexpr
  252. def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, padding, x_dtype):
  253. if x_shape_nchw[2] == kernel_matrix_shape[2] \
  254. and x_shape_nchw[3] == kernel_matrix_shape[3] \
  255. and padding == 'VALID':
  256. return None
  257. kernel_matrix = np.ones(kernel_matrix_shape)
  258. return Tensor(kernel_matrix, x_dtype)
@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Grad definition for `AvgPool` operation."""

    # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
    if self.target == "GPU":
        avgpool_grad_gpu = G.AvgPoolGradGpu(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)

        def bprop_gpu(x, out, dout):
            # GPU grad kernel takes the forward input and output directly.
            dx = avgpool_grad_gpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_gpu
    elif self.target == "GE":
        avgpool_grad_ge = G.AvgPoolGrad(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)
        shape_op = P.Shape()

        def bprop_ge(x, out, dout):
            # GE's grad op only needs the original input shape, not its values.
            dx = avgpool_grad_ge(shape_op(x), dout)
            return (dx,)

        bprop_fn = bprop_ge
    else:
        avgpool_grad_vm = G.AvgPoolGradVm(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)
        k_size_nchw = avgpool_grad_vm.ksize
        stride_nchw = avgpool_grad_vm.strides
        padding = self.padding

        def bprop_vm(x, out, dout):
            x_shape_nchw = F.shape(x)
            x_dtype = F.dtype(x)
            kernel_matrix_shape = (1, x_shape_nchw[1],
                                   k_size_nchw[2],
                                   k_size_nchw[3])
            # The VM/TBE kernel needs two constexpr-built assist tensors:
            # per-output averaging factors and an all-ones kernel.
            mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, padding, x_dtype)
            kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, padding, x_dtype)
            dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
            return (dx,)

        bprop_fn = bprop_vm
    return bprop_fn
  302. @bprop_getters.register(P.DropoutGenMask)
  303. def get_bprop_dropout_gen_mask(self):
  304. """Grad definition for `DropoutGenMask` operation."""
  305. def bprop(shape, keep_prob, out, dout):
  306. return (zeros_like(shape), zeros_like(keep_prob))
  307. return bprop
  308. @bprop_getters.register(P.DropoutDoMask)
  309. def get_bprop_dropout_do_mask(self):
  310. """Grad definition for `DropoutDoMask` operation."""
  311. do_mask = P.DropoutDoMask()
  312. def bprop(x, y, keep_prob, out, dout):
  313. return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
  314. return bprop
  315. @bprop_getters.register(P.ReLU)
  316. def get_bprop_relu(self):
  317. """Grad definition for `ReLU` operation."""
  318. input_grad = G.ReluGrad()
  319. def bprop(x, out, dout):
  320. dx = input_grad(dout, out)
  321. return (dx,)
  322. return bprop
  323. @bprop_getters.register(P.ReLU6)
  324. def get_bprop_relu6(self):
  325. """Grad definition for `ReLU6` operation."""
  326. input_grad = G.ReLU6Grad()
  327. def bprop(x, out, dout):
  328. dx = input_grad(dout, x)
  329. return (dx,)
  330. return bprop
  331. @bprop_getters.register(P.ReLUV2)
  332. def get_bprop_relu_v2(self):
  333. """Grad definition for `ReLUV2` operation."""
  334. input_grad = G.ReluGradV2()
  335. def bprop(x, out, dout):
  336. mask = out[1]
  337. dx = input_grad(dout[0], mask)
  338. return (dx,)
  339. return bprop
  340. @bprop_getters.register(P.HSwish)
  341. def get_bprop_hswish(self):
  342. """Grad definition for `HSwish` operation."""
  343. input_grad = G.HSwishGrad()
  344. def bprop(x, out, dout):
  345. dx = input_grad(dout, x)
  346. return (dx,)
  347. return bprop
  348. @bprop_getters.register(P.HSigmoid)
  349. def get_bprop_hsigmoid(self):
  350. """Grad definition for `HSigmoid` operation."""
  351. input_grad = G.HSigmoidGrad()
  352. def bprop(x, out, dout):
  353. dx = input_grad(dout, x)
  354. return (dx,)
  355. return bprop
  356. @bprop_getters.register(P.Elu)
  357. def get_bprop_elu(self):
  358. """Grad definition for `Elu` operation."""
  359. input_grad = G.EluGrad()
  360. def bprop(x, out, dout):
  361. dx = input_grad(dout, out)
  362. return (dx,)
  363. return bprop
  364. @bprop_getters.register(P.Sigmoid)
  365. def get_bprop_sigmoid(self):
  366. """Grad definition for `Sigmoid` operation."""
  367. input_grad = G.SigmoidGrad()
  368. def bprop(x, out, dout):
  369. dx = input_grad(out, dout)
  370. return (dx,)
  371. return bprop
  372. @bprop_getters.register(P.Softmax)
  373. def get_bprop_softmax(self):
  374. """Grad definition for `Softmax` operation."""
  375. sum_func = P.ReduceSum(keep_dims=True)
  376. sub = P.Sub()
  377. mul = P.Mul()
  378. axis = self.axis
  379. def bprop(x, out, dout):
  380. dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
  381. return (dx,)
  382. return bprop
  383. @bprop_getters.register(P.LogSoftmax)
  384. def get_bprop_log_softmax(self):
  385. """Grad definition for `LogSoftmax` operation."""
  386. logsoftmax_grad = G.LogSoftmaxGrad(self.axis)
  387. def bprop(x, out, dout):
  388. dx = logsoftmax_grad(out, dout)
  389. return (dx,)
  390. return bprop
  391. @bprop_getters.register(P.Softplus)
  392. def get_bprop_softplus(self):
  393. """Grad definition for `Softplus` operation."""
  394. softplus_grad = G.SoftplusGrad()
  395. def bprop(x, out, dout):
  396. dx = softplus_grad(dout, x)
  397. return (dx,)
  398. return bprop
  399. @bprop_getters.register(P.Softsign)
  400. def get_bprop_softsign(self):
  401. """Grad definition for `Softsign` operation."""
  402. mul = P.Mul()
  403. absolute = P.Abs()
  404. div = P.Div()
  405. square = P.Square()
  406. def bprop(x, out, dout):
  407. dx = mul(dout, div(1, square(1 + absolute(x))))
  408. return (dx,)
  409. return bprop
  410. @bprop_getters.register(P.Tanh)
  411. def get_bprop_tanh(self):
  412. """Grad definition for `Tanh` operation."""
  413. tanh_grad = SG.TanhGrad()
  414. def bprop(x, out, dout):
  415. dx = tanh_grad(out, dout)
  416. return (dx,)
  417. return bprop
  418. @bprop_getters.register(P.Gelu)
  419. def get_bprop_gelu(self):
  420. """Grad definition for `Gelu` operation."""
  421. input_grad = G.GeluGrad()
  422. def bprop(x, out, dout):
  423. dx = input_grad(dout, x, out)
  424. return (dx,)
  425. return bprop
  426. @bprop_getters.register(P.FusedBatchNorm)
  427. def get_bprop_fused_batch_norm(self):
  428. """Grad definition for `FusedBatchNorm` operation."""
  429. input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum)
  430. def bprop(x, scale, b, mean, variance, out, dout):
  431. saved_mean = out[3]
  432. saved_variance = out[4]
  433. out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
  434. dx = out[0]
  435. dscale = out[1]
  436. dbias = out[2]
  437. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  438. return bprop
  439. @bprop_getters.register(P.FusedBatchNormEx)
  440. def get_bprop_fused_batch_norm_ex(self):
  441. """Grad definition for `FusedBatchNormEx` operation."""
  442. input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum)
  443. def bprop(x, scale, b, mean, variance, out, dout):
  444. saved_mean = out[3]
  445. saved_variance = out[4]
  446. reserve = out[5]
  447. out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
  448. dx = out[0]
  449. dscale = out[1]
  450. dbias = out[2]
  451. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  452. return bprop
  453. @bprop_getters.register(P.BatchNorm)
  454. def get_bprop_batch_norm(self):
  455. """Grad definition for `BatchNorm` operation."""
  456. is_training = self.is_training
  457. input_grad = G.BatchNormGrad(is_training, self.epsilon)
  458. def bprop(x, scale, b, mean, variance, out, dout):
  459. if is_training:
  460. saved_reserve_1 = out[3]
  461. saved_reserve_2 = out[4]
  462. else:
  463. saved_reserve_1 = mean
  464. saved_reserve_2 = variance
  465. out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
  466. dx = out[0]
  467. dscale = out[1]
  468. dbias = out[2]
  469. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  470. return bprop
  471. @bprop_getters.register(P.LayerNorm)
  472. def get_bprop_layer_norm(self):
  473. """Grad definition for `LayerNorm` operation."""
  474. layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)
  475. def bprop(x, gamma, beta, out, dout):
  476. dx, d_gamma, d_beta = layer_norm_grad(
  477. x, dout[0], out[2], out[1], gamma)
  478. return dx, d_gamma, d_beta
  479. return bprop
  480. @bprop_getters.register(P.L2Normalize)
  481. def get_bprop_l2normalize(self):
  482. """Grad definition for `L2Normalize` operation."""
  483. input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)
  484. def bprop(x, out, dout):
  485. dx = input_grad(x, out, dout)
  486. return (dx,)
  487. return bprop
  488. @bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
  489. def get_bprop_softmax_cross_entropy_with_logits(self):
  490. """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
  491. expand = P.ExpandDims()
  492. def bprop(logits, labels, out, dout):
  493. grad = out[1]
  494. grad = grad * expand(dout[0], -1)
  495. return grad, zeros_like(labels)
  496. return bprop
  497. @bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
  498. def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
  499. """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
  500. is_grad = self.is_grad
  501. grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)
  502. def bprop(logits, labels, out, dout):
  503. grad = out[0]
  504. if not is_grad:
  505. # if construct use loss
  506. grad = grad_op(logits, labels)
  507. grad = F.depend(grad, out)
  508. grad = grad * dout
  509. return grad, zeros_like(labels)
  510. return bprop
  511. @bprop_getters.register(P.ResizeBilinear)
  512. def get_bprop_resize_bilinear(self):
  513. """Grad definition for `ResizeBilinear` operation."""
  514. resize_grad = G.ResizeBilinearGrad(self.align_corners)
  515. def bprop(x, out, dout):
  516. dx = resize_grad(dout, x)
  517. return (dx,)
  518. return bprop
  519. @bprop_getters.register(P.OneHot)
  520. def get_bprop_onehot(self):
  521. """Grad definition for `OneHot` operation."""
  522. def bprop(indices, depth, on_value, off_value, out, dout):
  523. return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
  524. return bprop
  525. @constexpr
  526. def _range_op(start, limit, delta, dtype):
  527. """helper function for Grad TopK"""
  528. output_tensor = Tensor(list(range(start, limit, delta)), dtype)
  529. return output_tensor
  530. @constexpr
  531. def _get_1d_shape(in_shape):
  532. """helper function for Grad TopK"""
  533. out_shape = 1
  534. for i in in_shape:
  535. out_shape *= i
  536. return (out_shape,)
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # Scatter the gradient of the selected values (dout[0]) back to the
        # positions recorded by the forward indices (out[1]); every other
        # position of input_x receives zero.
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]
        # Collapse all leading axes so each row can be addressed flatly.
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]
        # Flat offset of each row's start: [0, in_lastdim, 2*in_lastdim, ...]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
        # expand_dims to (outerdim, 1), then broadcast over the row entries to
        # turn per-row indices into global flat indices.
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)
        # Scatter into the flat input shape, then restore the original shape.
        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        # k is an integer input with no gradient.
        return out_grad, zeros_like(k)

    return bprop
  567. @bprop_getters.register(P.SmoothL1Loss)
  568. def get_bprop_smooth_l1_loss(self):
  569. """Grad definition for `SmoothL1Loss` operation."""
  570. grad = G.SmoothL1LossGrad(self.beta)
  571. def bprop(prediction, target, out, dout):
  572. dx = grad(prediction, target, dout)
  573. dy = grad(target, prediction, dout)
  574. return dx, dy
  575. return bprop
  576. @bprop_getters.register(P.L2Loss)
  577. def get_bprop_l2_loss(self):
  578. """Grad definition for `L2Loss` operation."""
  579. def bprop(x, out, dout):
  580. dx = x * dout
  581. return (dx,)
  582. return bprop
  583. @bprop_getters.register(P.RNNTLoss)
  584. def get_bprop_rnnt_loss(self):
  585. """Grad definition for `RNNTLoss` operation."""
  586. def bprop(acts, labels, act_lens, label_lens, out, dout):
  587. grad = out[1]
  588. return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
  589. return bprop
  590. @bprop_getters.register(P.PReLU)
  591. def get_bprop_prelu(self):
  592. """Grad definition for `PReLU` operation."""
  593. grad = G.PReLUGrad()
  594. def bprop(x, w, out, dout):
  595. dx, dw = grad(dout, x, w)
  596. return dx, dw
  597. return bprop
@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Grad definition for `LSTM` operation."""
    lstm_grad_data = G.LSTMGradData(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad_weight = G.LSTMGradWeight(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad = G.LSTMGrad(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    def bprop(x, hx, cx, w, out, dout):
        # Two-step backward: data gradients first, then weight gradients.
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        # depend() inserts a control dependency so the weight-grad op runs
        # after the data-grad op (both consume reserve/state workspaces).
        dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
        return dx, dhx, dcx, dw

    def bprop_cpu(x, hx, cx, w, out, dout):
        # CPU path: a single fused grad kernel produces all gradients.
        y, hy, cy, reserve, _ = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
        return dx, dhx, dcx, dw

    if context.get_context('device_target') == "CPU":
        return bprop_cpu
    return bprop
@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _, = dout
        # Only the last time step's state gradients seed the backward pass.
        dh = dh[-1]
        dc = dc[-1]
        # Forward outputs carry the saved gate activations (i/j/f/o) and
        # tanh(c) needed by the grad kernel.
        y, h, c, i, j, f, o, tanhct = out
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        # `(0)` is a scalar zero placeholder for the non-differentiable
        # seq_length input.
        return dx, dw, db, (0), dh_prev, dc_prev

    return bprop
  653. @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
  654. def get_bprop_sigmoid_crossentropy_with_logits(self):
  655. """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
  656. op = G.SigmoidCrossEntropyWithLogitsGrad()
  657. def bprop(x, y, out, dout):
  658. dx = op(x, y, dout)
  659. return (dx, zeros_like(y))
  660. return bprop
  661. @bprop_getters.register(P.Pad)
  662. def get_bprop_pad(self):
  663. """Grad definition for `Pad` operation."""
  664. shape_op = P.Shape()
  665. paddings = self.paddings
  666. def bprop(x, out, dout):
  667. begin = ()
  668. for item in paddings:
  669. begin += (item[0],)
  670. shp = shape_op(x)
  671. dx = P.Slice()(dout, begin, shp)
  672. return (dx,)
  673. return bprop
  674. @bprop_getters.register(P.MirrorPad)
  675. def get_bprop_mirror_pad(self):
  676. """Grad definition for `MirrorPad` operation."""
  677. mirror_pad_grad = G.MirrorPadGrad(self.mode)
  678. def bprop(x, paddings, out, dout):
  679. dx = mirror_pad_grad(dout, paddings)
  680. return (dx, zeros_like(paddings))
  681. return bprop
  682. @bprop_getters.register(P.ROIAlign)
  683. def get_bprop_roi_align(self):
  684. """Grad definition for `ROIAlign` operation."""
  685. shape_op = P.Shape()
  686. pooled_height = self.pooled_height
  687. pooled_width = self.pooled_width
  688. spatial_scale = self.spatial_scale
  689. sample_num = self.sample_num
  690. def bprop(inputs, rois, out, dout):
  691. inputs_shape = shape_op(inputs)
  692. dx = G.ROIAlignGrad(inputs_shape,
  693. pooled_height,
  694. pooled_width,
  695. spatial_scale,
  696. sample_num,
  697. )(dout, rois)
  698. return dx, zeros_like(rois)
  699. return bprop
  700. @bprop_getters.register(P.Conv2DBackpropInput)
  701. def get_bprop_conv2d_backprop_input(self):
  702. """Grad definition for `Conv2DBackpropInput` operation."""
  703. filter_grad = G.Conv2DBackpropFilter(
  704. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  705. dilation=self.dilation, stride=self.stride, group=self.group
  706. )
  707. input_grad = P.Conv2D(
  708. self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
  709. dilation=self.dilation, stride=self.stride, group=self.group
  710. )
  711. def bprop(x, w, f_sizes, out, dout):
  712. dx = input_grad(dout, w)
  713. dw = filter_grad(x, dout, F.shape(w))
  714. return dx, dw, zeros_like(f_sizes)
  715. return bprop
  716. @bprop_getters.register(P.BinaryCrossEntropy)
  717. def get_bprop_binary_cross_entropy(self):
  718. """Grad definition for `BinaryCrossEntropy` operation."""
  719. grad = G.BinaryCrossEntropyGrad(self.reduction)
  720. def bprop(x, y, weight, out, dout):
  721. dx = grad(x, y, dout, weight)
  722. return dx, zeros_like(y), zeros_like(weight)
  723. return bprop
  724. @bprop_getters.register(P.KLDivLoss)
  725. def get_bprop_kl_div_loss(self):
  726. """Grad definition for `KLDivLoss` operation."""
  727. grad = G.KLDivLossGrad(self.reduction)
  728. def bprop(x, y, out, dout):
  729. dx, dy = grad(x, y, dout)
  730. return dx, dy
  731. return bprop
  732. @bprop_getters.register(P.Dropout)
  733. def get_bprop_dropout(self):
  734. """Grad definition for `Dropout` operation."""
  735. grad = G.DropoutGrad(self.keep_prob)
  736. def bprop(x, out, dout):
  737. _, mask = out
  738. dy, _ = dout
  739. dx = grad(dy, mask)
  740. return (dx,)
  741. return bprop
  742. @bprop_getters.register(P.CTCLoss)
  743. def get_bprop_ctc_loss(self):
  744. """Grad definition for `CTCLoss` operation"""
  745. expand = P.ExpandDims()
  746. def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
  747. grad_loss = out[1]
  748. grad = grad_loss * expand(dout[0], -1)
  749. return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
  750. return bprop
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        # Forward outputs carry the saved gate activations (it/jt/ft/ot) and
        # tanh(ct) needed to backprop through the cell state.
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        # 1) cell-state path: gate gradients and previous cell-state grad.
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        # 2) input path: gradients w.r.t. x and h through the weight matrix.
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        # 3) weight path; depend() orders it after the input-grad step.
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop
  768. @bprop_getters.register(P.LRN)
  769. def get_bprop_lrn(self):
  770. """Grad definition for `LRN` operation."""
  771. grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)
  772. def bprop(x, out, dout):
  773. dx = grad(dout, x, out)
  774. return (dx,)
  775. return bprop