You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_nn_ops.py 40 kB

6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of neural network related operations."""
  16. import os
  17. import numpy as np
  18. from mindspore.ops import _selected_grad_ops as SG
  19. from mindspore.ops.primitive import constexpr
  20. from mindspore.common.tensor import Tensor
  21. from mindspore.ops.operations import nn_ops as nps
  22. from .grad_base import bprop_getters
  23. from .. import functional as F
  24. from .. import operations as P
  25. from ...common import dtype as mstype
  26. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  27. from ..operations import _grad_ops as G
  28. from ..operations import _inner_ops as inner
  29. from ... import context
# When ENV_FORCE_BPROP_SEQ == '1', conv bprops insert an F.depend edge so
# dw is computed only after dx (see Conv2D / DepthwiseConv2dNative bprops).
env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
  31. @bprop_getters.register(P.BiasAdd)
  32. def get_bprop_bias_add(self):
  33. """Grad definition for `BiasAdd` operation."""
  34. bias_grad = SG.BiasAddGrad(self.data_format)
  35. def bprop(x, w, out, dout):
  36. return dout, bias_grad(dout)
  37. return bprop
  38. @bprop_getters.register(P.Conv2D)
  39. def get_bprop_conv2d(self):
  40. """Grad definition for `Conv2D` operation."""
  41. input_grad = P.Conv2DBackpropInput(
  42. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  43. dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
  44. )
  45. filter_grad = G.Conv2DBackpropFilter(
  46. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  47. dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
  48. )
  49. get_shape = P.Shape()
  50. def bprop(x, w, out, dout):
  51. dx = input_grad(dout, w, get_shape(x))
  52. if env_force_bprop_seq == '1':
  53. x = F.depend(x, dx)
  54. dw = filter_grad(dout, x, get_shape(w))
  55. return dx, dw
  56. return bprop
  57. @bprop_getters.register(nps.Conv3D)
  58. def get_bprop_conv3d(self):
  59. """Grad definition for `Conv3D` operation."""
  60. input_grad = nps.Conv3DBackpropInput(
  61. self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
  62. pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
  63. )
  64. filter_grad = G.Conv3DBackpropFilter(
  65. self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
  66. pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
  67. )
  68. get_shape = P.Shape()
  69. def bprop(x, w, out, dout):
  70. dx = input_grad(w, dout, get_shape(x))
  71. dw = filter_grad(x, dout, get_shape(w))
  72. return dx, dw
  73. return bprop
  74. @bprop_getters.register(nps.Conv3DTranspose)
  75. def get_bprop_conv3d_transpose(self):
  76. """Grad definition for `Conv3DTranspose` operation."""
  77. input_grad = nps.Conv3D(
  78. out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
  79. pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
  80. )
  81. filter_grad = G.Conv3DBackpropFilter(
  82. out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
  83. pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
  84. )
  85. input_size = self.input_size
  86. def bprop(x, w, out, dout):
  87. dx = input_grad(dout, w)
  88. dw = filter_grad(dout, x, F.shape(w))
  89. return dx, dw, zeros_like(input_size)
  90. return bprop
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation.

    Strategy: number every input pixel, run the forward op on the index map to
    learn which input pixel feeds each (patch, position) slot, build a sparse
    0/1 scatter matrix from that mapping, then matmul it with the reshaped
    dout to accumulate gradients back onto the input pixels.
    """
    get_shape = P.Shape()
    reshape = P.Reshape()
    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
                                                      strides=self.strides,
                                                      rates=self.rates,
                                                      padding=self.padding)
    concat = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()
    _, _, ksizes_row, ksizes_col = self.ksizes

    def bprop(x, out, dout):
        x_shape = get_shape(x)
        x_batch, x_depth, x_row, x_col = x_shape
        # Indices start at 1 so that row 0 of the scatter matrix can absorb
        # slots that map to padding; it is sliced away below.
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        # Run the forward op on the index image: each output slot now holds
        # the 1-based index of the input pixel it came from.
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
        out_shape = get_shape(out)
        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
        # Pair (input_pixel_index, output_slot_index) for every slot.
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        # Sparse 0/1 matrix: row = input pixel, column = output slot.
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        # Drop row 0 (the padding sink).
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))
        # Accumulate each output slot's gradient onto its source input pixel.
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    return bprop
  137. @bprop_getters.register(P.DepthwiseConv2dNative)
  138. def get_bprop_depthwise_conv2d_native(self):
  139. """Grad definition for `DepthwiseConv2dNative` operation."""
  140. input_grad = G.DepthwiseConv2dNativeBackpropInput(
  141. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
  142. self.dilation, self.group
  143. )
  144. filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
  145. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
  146. self.dilation, self.group
  147. )
  148. get_shape = P.Shape()
  149. def bprop(x, w, out, dout):
  150. dx = input_grad(get_shape(x), w, dout)
  151. if env_force_bprop_seq == '1':
  152. x = F.depend(x, dx)
  153. dw = filter_grad(x, get_shape(w), dout)
  154. return dx, dw
  155. return bprop
  156. @bprop_getters.register(P.MaxPoolWithArgmax)
  157. def get_bprop_max_pool_with_argmax(self):
  158. """Grad definition for `MaxPoolWithArgmax` operation."""
  159. maxpool_grad = G.MaxPoolGradWithArgmax(
  160. kernel_size=self.kernel_size,
  161. strides=self.strides,
  162. pad_mode=self.pad_mode)
  163. def bprop(x, out, dout):
  164. dx = maxpool_grad(x, dout[0], out[1])
  165. return (dx,)
  166. return bprop
  167. @bprop_getters.register(G.MaxPoolGrad)
  168. def get_bprop_max_pool_grad_grad(self):
  169. """Grad definition for `MaxPoolGrad` operation."""
  170. maxpool_grad_grad = G.MaxPoolGradGrad(
  171. kernel_size=self.kernel_size,
  172. strides=self.strides,
  173. pad_mode=self.pad_mode)
  174. def bprop(x1, x2, grad, out, dout):
  175. dx1 = zeros_like(x1)
  176. dx2 = zeros_like(x2)
  177. dgrad = maxpool_grad_grad(x1, x2, dout)
  178. return (dx1, dx2, dgrad)
  179. return bprop
  180. @bprop_getters.register(G.MaxPoolGradGrad)
  181. def get_bprop_max_pool_grad_grad_grad(self):
  182. """Grad definition for `MaxPoolGradGrad` operation."""
  183. maxpool_grad = G.MaxPoolGrad(
  184. kernel_size=self.kernel_size,
  185. strides=self.strides,
  186. pad_mode=self.pad_mode)
  187. def bprop(x1, x2, grad, out, dout):
  188. dx1 = zeros_like(x1)
  189. dx2 = zeros_like(x2)
  190. dgrad = maxpool_grad(x1, x2, dout)
  191. return (dx1, dx2, dgrad)
  192. return bprop
  193. @bprop_getters.register(P.MaxPool)
  194. def get_bprop_max_pool_grad(self):
  195. """Grad definition for `MaxPool` operation."""
  196. maxpool_grad = G.MaxPoolGrad(
  197. kernel_size=self.kernel_size,
  198. strides=self.strides,
  199. pad_mode=self.pad_mode,
  200. data_format=self.format)
  201. def bprop(x, out, dout):
  202. dx = maxpool_grad(x, out, dout)
  203. return (dx,)
  204. return bprop
  205. def _windowed_output_size(input_size, ksize, stride, pad_mode):
  206. """
  207. helper func for AvgPoolGrad
  208. """
  209. tmp_output = 0
  210. tmp_pad_need = 0
  211. tmp_pad_before = 0
  212. tmp_pad_after = 0
  213. if pad_mode == 'VALID':
  214. tmp_output = (input_size - ksize + stride) // stride
  215. tmp_pad_before = 0
  216. tmp_pad_after = 0
  217. elif pad_mode == 'SAME':
  218. tmp_output = (input_size + stride - 1) // stride
  219. tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
  220. tmp_pad_before = tmp_pad_need // 2
  221. tmp_pad_after = tmp_pad_need - tmp_pad_before
  222. return tmp_output, tmp_pad_before, tmp_pad_after
  223. @constexpr
  224. def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
  225. """
  226. helper func for AvgPoolGrad.
  227. `assist_input_matrix` is a 2d matrix with input_shape after padding,
  228. the value of element which is padded is 0, else are 1.
  229. For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
  230. w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
  231. number of input that associate with output element.
  232. """
  233. n_input, c_input, h_input, w_input = x_shape
  234. h_ksize, w_ksize = ksize[2], ksize[3]
  235. h_stride, w_stride = stride[2], stride[3]
  236. n_output = n_input
  237. c_output = c_input
  238. h_output, w_output = 0, 0
  239. pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
  240. h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
  241. h_stride, pad_mode)
  242. w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
  243. w_stride, pad_mode)
  244. output_size = n_output * c_output * h_output * w_output
  245. output_shape = (n_output, c_output, h_output, w_output)
  246. output = np.array([0.0] * output_size)
  247. output = np.reshape(output, output_shape)
  248. in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
  249. assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
  250. if pad_top > 0:
  251. assist_input_matrix[:pad_top, :] = 0
  252. if pad_bottom > 0:
  253. assist_input_matrix[-pad_bottom:, :] = 0
  254. if pad_left > 0:
  255. assist_input_matrix[:, :pad_left] = 0
  256. if pad_right > 0:
  257. assist_input_matrix[:, -pad_right:] = 0
  258. for h in range(h_output):
  259. for w in range(w_output):
  260. curr_input = assist_input_matrix[h * h_stride: h * h_stride + h_ksize, w * w_stride: w * w_stride + w_ksize]
  261. curr_sum = np.sum(curr_input)
  262. if curr_sum > 0:
  263. output[:, :, h, w] = 1. / curr_sum
  264. return Tensor(output, x_dtype)
  265. @constexpr
  266. def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype):
  267. kernel_matrix = np.ones(kernel_matrix_shape)
  268. return Tensor(kernel_matrix, x_dtype)
  269. @bprop_getters.register(P.AvgPool)
  270. def get_bprop_avg_pool_grad(self):
  271. """Grad definition for `AvgPool` operation."""
  272. # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
  273. if self.target == "GPU":
  274. avgpool_grad_gpu = G.AvgPoolGradGpu(
  275. kernel_size=self.kernel_size,
  276. strides=self.strides,
  277. pad_mode=self.pad_mode,
  278. data_format=self.format)
  279. def bprop_gpu(x, out, dout):
  280. dx = avgpool_grad_gpu(x, out, dout)
  281. return (dx,)
  282. bprop_fn = bprop_gpu
  283. elif self.target == "CPU":
  284. avgpool_grad_cpu = G.AvgPoolGradCpu(
  285. kernel_size=self.kernel_size,
  286. strides=self.strides,
  287. pad_mode=self.pad_mode,
  288. data_format=self.format)
  289. def bprop_cpu(x, out, dout):
  290. dx = avgpool_grad_cpu(x, out, dout)
  291. return (dx,)
  292. bprop_fn = bprop_cpu
  293. elif self.target == "GE":
  294. avgpool_grad_ge = G.AvgPoolGrad(
  295. kernel_size=self.kernel_size,
  296. strides=self.strides,
  297. pad_mode=self.pad_mode)
  298. shape_op = P.Shape()
  299. def bprop_ge(x, out, dout):
  300. dx = avgpool_grad_ge(shape_op(x), dout)
  301. return (dx,)
  302. bprop_fn = bprop_ge
  303. else:
  304. avgpool_grad_vm = G.AvgPoolGradVm(
  305. kernel_size=self.kernel_size,
  306. strides=self.strides,
  307. pad_mode=self.pad_mode)
  308. k_size_nchw = avgpool_grad_vm.kernel_size
  309. stride_nchw = avgpool_grad_vm.strides
  310. pad_mode = self.pad_mode
  311. def bprop_vm(x, out, dout):
  312. x_shape_nchw = F.shape(x)
  313. x_dtype = F.dtype(x)
  314. kernel_matrix_shape = (1, x_shape_nchw[1],
  315. k_size_nchw[2],
  316. k_size_nchw[3])
  317. mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, pad_mode, x_dtype)
  318. kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype)
  319. dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
  320. return (dx,)
  321. bprop_fn = bprop_vm
  322. return bprop_fn
  323. @bprop_getters.register(P.DropoutGenMask)
  324. def get_bprop_dropout_gen_mask(self):
  325. """Grad definition for `DropoutGenMask` operation."""
  326. def bprop(shape, keep_prob, out, dout):
  327. return (zeros_like(shape), zeros_like(keep_prob))
  328. return bprop
  329. @bprop_getters.register(P.DropoutDoMask)
  330. def get_bprop_dropout_do_mask(self):
  331. """Grad definition for `DropoutDoMask` operation."""
  332. do_mask = P.DropoutDoMask()
  333. def bprop(x, y, keep_prob, out, dout):
  334. return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
  335. return bprop
  336. @bprop_getters.register(P.Mish)
  337. def get_bprop_mish(self):
  338. """Grad definition for `Mish` operation."""
  339. tanh = P.Tanh()
  340. tanh_grad = SG.TanhGrad()
  341. softplus = P.Softplus()
  342. softplus_grad = G.SoftplusGrad()
  343. def bprop(x, out, dout):
  344. dx1 = tanh(softplus(x))
  345. dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
  346. dx = (dx1 * dout + dx2)
  347. return (dx,)
  348. return bprop
  349. @bprop_getters.register(P.SeLU)
  350. def get_bprop_selu(self):
  351. """Grad definition for `SeLU` operation."""
  352. scale = 1.0507009873554804934193349852946
  353. elu_grad = G.EluGrad()
  354. def bprop(x, out, dout):
  355. dx = elu_grad(dout, out) * scale
  356. return (dx,)
  357. return bprop
  358. @bprop_getters.register(P.MulNoNan)
  359. def get_bprop_mul_no_nan(self):
  360. """Grad definition for `MulNoNan` operation."""
  361. mul_no_nan = P.MulNoNan()
  362. reduce_sum = P.ReduceSum()
  363. reshape = P.Reshape()
  364. def bprop(x, y, out, dout):
  365. x_shape = F.shape(x)
  366. y_shape = F.shape(y)
  367. dx = mul_no_nan(dout, y)
  368. dy = mul_no_nan(x, dout)
  369. broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
  370. if broadcast_x != ():
  371. dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
  372. if broadcast_y != ():
  373. dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
  374. return dx, dy
  375. return bprop
  376. @bprop_getters.register(P.ReLU)
  377. def get_bprop_relu(self):
  378. """Grad definition for `ReLU` operation."""
  379. input_grad = G.ReluGrad()
  380. def bprop(x, out, dout):
  381. dx = input_grad(dout, out)
  382. return (dx,)
  383. return bprop
  384. @bprop_getters.register(G.ReluGrad)
  385. def get_bprop_relu_grad(self):
  386. """Grad definition for `ReLUGrad` operation."""
  387. input_grad = G.ReluGrad()
  388. def bprop(grad, y, out, dout):
  389. dgrad = input_grad(dout, y)
  390. return dgrad, zeros_like(y)
  391. return bprop
  392. @bprop_getters.register(P.ReLU6)
  393. def get_bprop_relu6(self):
  394. """Grad definition for `ReLU6` operation."""
  395. input_grad = G.ReLU6Grad()
  396. def bprop(x, out, dout):
  397. dx = input_grad(dout, x)
  398. return (dx,)
  399. return bprop
  400. @bprop_getters.register(P.ReLUV2)
  401. def get_bprop_relu_v2(self):
  402. """Grad definition for `ReLUV2` operation."""
  403. input_grad = G.ReluGradV2()
  404. def bprop(x, out, dout):
  405. mask = out[1]
  406. dx = input_grad(dout[0], mask)
  407. return (dx,)
  408. return bprop
  409. @bprop_getters.register(P.HSwish)
  410. def get_bprop_hswish(self):
  411. """Grad definition for `HSwish` operation."""
  412. input_grad = G.HSwishGrad()
  413. def bprop(x, out, dout):
  414. dx = input_grad(dout, x)
  415. return (dx,)
  416. return bprop
  417. @bprop_getters.register(P.HSigmoid)
  418. def get_bprop_hsigmoid(self):
  419. """Grad definition for `HSigmoid` operation."""
  420. input_grad = G.HSigmoidGrad()
  421. def bprop(x, out, dout):
  422. dx = input_grad(dout, x)
  423. return (dx,)
  424. return bprop
  425. @bprop_getters.register(P.Elu)
  426. def get_bprop_elu(self):
  427. """Grad definition for `Elu` operation."""
  428. input_grad = G.EluGrad()
  429. def bprop(x, out, dout):
  430. dx = input_grad(dout, out)
  431. return (dx,)
  432. return bprop
  433. @bprop_getters.register(P.Sigmoid)
  434. def get_bprop_sigmoid(self):
  435. """Grad definition for `Sigmoid` operation."""
  436. input_grad = G.SigmoidGrad()
  437. def bprop(x, out, dout):
  438. dx = input_grad(out, dout)
  439. return (dx,)
  440. return bprop
  441. @bprop_getters.register(G.SigmoidGrad)
  442. def get_bprop_sigmoid_grad(self):
  443. """Grad definition for `SigmoidGrad` operation."""
  444. sigmoid_grad = G.SigmoidGrad()
  445. def bprop(y, grad, out, dout):
  446. dy = dout * grad * (1. - 2 * y)
  447. dgrad = sigmoid_grad(y, dout)
  448. return dy, dgrad
  449. return bprop
  450. @constexpr
  451. def _get_transpose_axis(x_shp, axis):
  452. rank = len(x_shp)
  453. if axis < 0:
  454. axis += rank
  455. reverse_axis = [i for i in range(rank)]
  456. reverse_axis[axis] = rank - 1
  457. reverse_axis[rank - 1] = axis
  458. return tuple(reverse_axis)
  459. @bprop_getters.register(P.Softmax)
  460. def get_bprop_softmax(self):
  461. """Grad definition for `Softmax` operation."""
  462. sum_func = P.ReduceSum(keep_dims=True)
  463. sub = P.Sub()
  464. mul = P.Mul()
  465. get_shape = P.Shape()
  466. transpose = P.Transpose()
  467. axis = self.axis
  468. if not isinstance(axis, int):
  469. axis = axis[0]
  470. def bprop(x, out, dout):
  471. # dx = (dout - sum(dout * out)) * out
  472. # This formula is correct only when the `axis` is the last dimension.
  473. # In order to support the scenario where the `axis` is other values,
  474. # we transpose the data of the `axis` dimension to the last dimension for calculation,
  475. # and then transpose it back after the calculation.
  476. reverse_axis = _get_transpose_axis(get_shape(x), axis)
  477. out = transpose(out, reverse_axis)
  478. dout = transpose(dout, reverse_axis)
  479. dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
  480. dx = transpose(dx, reverse_axis)
  481. return (dx,)
  482. return bprop
  483. @bprop_getters.register(P.LogSoftmax)
  484. def get_bprop_log_softmax(self):
  485. """Grad definition for `LogSoftmax` operation."""
  486. logsoftmax_grad = G.LogSoftmaxGrad(self.axis)
  487. def bprop(x, out, dout):
  488. dx = logsoftmax_grad(out, dout)
  489. return (dx,)
  490. return bprop
  491. @bprop_getters.register(P.Softplus)
  492. def get_bprop_softplus(self):
  493. """Grad definition for `Softplus` operation."""
  494. softplus_grad = G.SoftplusGrad()
  495. def bprop(x, out, dout):
  496. dx = softplus_grad(dout, x)
  497. return (dx,)
  498. return bprop
  499. @bprop_getters.register(P.Softsign)
  500. def get_bprop_softsign(self):
  501. """Grad definition for `Softsign` operation."""
  502. mul = P.Mul()
  503. absolute = P.Abs()
  504. div = P.Div()
  505. square = P.Square()
  506. def bprop(x, out, dout):
  507. dx = mul(dout, div(1, square(1 + absolute(x))))
  508. return (dx,)
  509. return bprop
  510. @bprop_getters.register(P.Tanh)
  511. def get_bprop_tanh(self):
  512. """Grad definition for `Tanh` operation."""
  513. tanh_grad = SG.TanhGrad()
  514. def bprop(x, out, dout):
  515. dx = tanh_grad(out, dout)
  516. return (dx,)
  517. return bprop
  518. @bprop_getters.register(G.TanhGrad)
  519. def get_bprop_tanh_grad(self):
  520. """Grad definition for `TanhGrad` operation."""
  521. tanh_grad = G.TanhGrad()
  522. def bprop(y, grad, out, dout):
  523. dy = dout * -2.0 * grad * y
  524. dgrad = tanh_grad(y, dout)
  525. return dy, dgrad
  526. return bprop
  527. @bprop_getters.register(P.GeLU)
  528. def get_bprop_gelu(self):
  529. """Grad definition for `GeLU` operation."""
  530. input_grad = G.GeLUGrad()
  531. def bprop(x, out, dout):
  532. dx = input_grad(dout, x, out)
  533. return (dx,)
  534. return bprop
  535. @bprop_getters.register(P.Gelu)
  536. def get_bprop_gelu_2(self):
  537. """Grad definition for `GeLU` operation."""
  538. input_grad = G.GeLUGrad()
  539. def bprop(x, out, dout):
  540. dx = input_grad(dout, x, out)
  541. return (dx,)
  542. return bprop
  543. @bprop_getters.register(P.FastGeLU)
  544. def get_bprop_fast_gelu(self):
  545. """Grad definition for `FastGeLU` operation."""
  546. input_grad = G.FastGeLUGrad()
  547. def bprop(x, out, dout):
  548. dx = input_grad(dout, x)
  549. return (dx,)
  550. return bprop
  551. @bprop_getters.register(P.FastGelu)
  552. def get_bprop_fast_gelu_2(self):
  553. """Grad definition for `FastGeLU` operation."""
  554. input_grad = G.FastGeLUGrad()
  555. def bprop(x, out, dout):
  556. dx = input_grad(dout, x)
  557. return (dx,)
  558. return bprop
  559. @bprop_getters.register(P.FusedBatchNorm)
  560. def get_bprop_fused_batch_norm(self):
  561. """Grad definition for `FusedBatchNorm` operation."""
  562. input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum)
  563. target_cpu = False
  564. if self.target == "CPU":
  565. input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
  566. target_cpu = True
  567. def bprop(x, scale, b, mean, variance, out, dout):
  568. saved_mean = out[3]
  569. saved_variance = out[4]
  570. if target_cpu:
  571. out = input_grad(dout[0], x, scale, b, saved_mean, saved_variance)
  572. else:
  573. out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
  574. dx = out[0]
  575. dscale = out[1]
  576. dbias = out[2]
  577. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  578. return bprop
  579. @bprop_getters.register(P.FusedBatchNormEx)
  580. def get_bprop_fused_batch_norm_ex(self):
  581. """Grad definition for `FusedBatchNormEx` operation."""
  582. input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format)
  583. def bprop(x, scale, b, mean, variance, out, dout):
  584. saved_mean = out[3]
  585. saved_variance = out[4]
  586. reserve = out[5]
  587. out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
  588. dx = out[0]
  589. dscale = out[1]
  590. dbias = out[2]
  591. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  592. return bprop
  593. @bprop_getters.register(P.InstanceNorm)
  594. def get_bprop_instance_norm(self):
  595. """Grad definition for `InstanceNorm` operation."""
  596. is_training = self.is_training
  597. input_grad = G.InstanceNormGrad(is_training, self.epsilon, self.momentum)
  598. def bprop(x, gamma, beta, mean, variance, out, dout):
  599. saved_mean = out[1]
  600. saved_variance = out[2]
  601. out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
  602. dx = out[0]
  603. dgamma = out[1]
  604. dbeta = out[2]
  605. return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)
  606. return bprop
  607. @bprop_getters.register(P.BatchNorm)
  608. def get_bprop_batch_norm(self):
  609. """Grad definition for `BatchNorm` operation."""
  610. is_training = self.is_training
  611. input_grad = G.BatchNormGrad(is_training, self.epsilon)
  612. def bprop(x, scale, b, mean, variance, out, dout):
  613. if is_training:
  614. saved_reserve_1 = out[3]
  615. saved_reserve_2 = out[4]
  616. else:
  617. saved_reserve_1 = mean
  618. saved_reserve_2 = variance
  619. out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
  620. dx = out[0]
  621. dscale = out[1]
  622. dbias = out[2]
  623. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  624. return bprop
  625. @bprop_getters.register(P.LayerNorm)
  626. def get_bprop_layer_norm(self):
  627. """Grad definition for `LayerNorm` operation."""
  628. layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)
  629. def bprop(x, gamma, beta, out, dout):
  630. dx, d_gamma, d_beta = layer_norm_grad(
  631. x, dout[0], out[2], out[1], gamma)
  632. return dx, d_gamma, d_beta
  633. return bprop
  634. @bprop_getters.register(G.LayerNormGrad)
  635. def get_bprop_layer_norm_grad(self):
  636. """Grad definition for `LayerNormGrad` operation."""
  637. layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)
  638. def bprop(x, dy, variance, mean, gamma, out, dout):
  639. d_x, d_dy, d_gamma = layer_norm_grad_grad(
  640. x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
  641. return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma
  642. return bprop
  643. @bprop_getters.register(P.L2Normalize)
  644. def get_bprop_l2normalize(self):
  645. """Grad definition for `L2Normalize` operation."""
  646. input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)
  647. def bprop(x, out, dout):
  648. dx = input_grad(x, out, dout)
  649. return (dx,)
  650. return bprop
  651. @bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
  652. def get_bprop_softmax_cross_entropy_with_logits(self):
  653. """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
  654. expand = P.ExpandDims()
  655. def bprop(logits, labels, out, dout):
  656. grad = out[1]
  657. grad = grad * expand(dout[0], -1)
  658. return grad, zeros_like(labels)
  659. return bprop
  660. @bprop_getters.register(P.NLLLoss)
  661. def get_bprop_nll_loss(self):
  662. """Grad definition for `NLLLoss` operation."""
  663. nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)
  664. def bprop(x, target, weight, out, dout):
  665. total_weight = out[1]
  666. dout_x = dout[0]
  667. dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
  668. return dx, zeros_like(target), zeros_like(weight)
  669. return bprop
@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
    is_grad = self.is_grad
    # Second instance of the forward op configured to emit the gradient
    # instead of the loss.
    grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)

    def bprop(logits, labels, out, dout):
        grad = out[0]
        if not is_grad:
            # The registered op returned the loss, so recompute the gradient here.
            grad = grad_op(logits, labels)
            # depend edge orders the recomputation after the forward op's output.
            grad = F.depend(grad, out)
        grad = grad * dout
        # Integer labels receive zero gradient.
        return grad, zeros_like(labels)

    return bprop
  684. @bprop_getters.register(P.ResizeBilinear)
  685. def get_bprop_resize_bilinear(self):
  686. """Grad definition for `ResizeBilinear` operation."""
  687. resize_grad = G.ResizeBilinearGrad(self.align_corners)
  688. def bprop(x, out, dout):
  689. dx = resize_grad(dout, x)
  690. return (dx,)
  691. return bprop
  692. @bprop_getters.register(P.OneHot)
  693. def get_bprop_onehot(self):
  694. """Grad definition for `OneHot` operation."""
  695. def bprop(indices, depth, on_value, off_value, out, dout):
  696. return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
  697. return bprop
  698. @constexpr
  699. def _range_op(start, limit, delta, dtype):
  700. """helper function for Grad TopK"""
  701. output_tensor = Tensor(list(range(start, limit, delta)), dtype)
  702. return output_tensor
  703. @constexpr
  704. def _get_1d_shape(in_shape):
  705. """helper function for Grad TopK"""
  706. out_shape = 1
  707. for i in in_shape:
  708. out_shape *= i
  709. return (out_shape,)
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # Scatter the incoming top-k gradients back to the positions selected
        # by the forward pass; every other position gets zero.
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]

        # out[1]: indices of the selected elements within the last dimension.
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]

        # Collapse all leading dimensions into rows: (outerdim, ind_lastdim).
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]

        # [0, in_lastdim, 2*in_lastdim, ..., (outerdim-1)*in_lastdim]:
        # flat offset of each row's start in the flattened input.
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)

        # Expand offsets to (outerdim, 1), broadcast-add to the per-row
        # indices, then flatten to global 1-D indices.
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)

        # Scatter into the flattened input shape, then restore the shape.
        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        # `k` is an integer input with no gradient.
        return out_grad, zeros_like(k)

    return bprop
  740. @bprop_getters.register(P.SmoothL1Loss)
  741. def get_bprop_smooth_l1_loss(self):
  742. """Grad definition for `SmoothL1Loss` operation."""
  743. grad = G.SmoothL1LossGrad(self.beta)
  744. def bprop(prediction, target, out, dout):
  745. dx = grad(prediction, target, dout)
  746. dy = grad(target, prediction, dout)
  747. return dx, dy
  748. return bprop
  749. @bprop_getters.register(P.L2Loss)
  750. def get_bprop_l2_loss(self):
  751. """Grad definition for `L2Loss` operation."""
  752. def bprop(x, out, dout):
  753. dx = x * dout
  754. return (dx,)
  755. return bprop
  756. @bprop_getters.register(P.RNNTLoss)
  757. def get_bprop_rnnt_loss(self):
  758. """Grad definition for `RNNTLoss` operation."""
  759. def bprop(acts, labels, act_lens, label_lens, out, dout):
  760. grad = out[1]
  761. return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
  762. return bprop
  763. @bprop_getters.register(P.PReLU)
  764. def get_bprop_prelu(self):
  765. """Grad definition for `PReLU` operation."""
  766. grad = G.PReLUGrad()
  767. def bprop(x, w, out, dout):
  768. dx, dw = grad(dout, x, w)
  769. return dx, dw
  770. return bprop
@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Grad definition for `LSTM` operation."""
    lstm_grad_data = G.LSTMGradData(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad_weight = G.LSTMGradWeight(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    # Single fused grad kernel, used on the CPU path below.
    lstm_grad = G.LSTMGrad(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    def bprop(x, hx, cx, w, out, dout):
        # Forward outputs: (y, hy, cy, reserve, state); dout matches this layout.
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        # depend edge forces the data-grad kernel to run before the weight grad.
        dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
        return dx, dhx, dcx, dw

    # CPU computes all gradients in one fused kernel instead of separate
    # data/weight kernels.
    def bprop_cpu(x, hx, cx, w, out, dout):
        y, hy, cy, reserve, _ = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
        return dx, dhx, dcx, dw

    if context.get_context('device_target') == "CPU":
        return bprop_cpu
    return bprop
@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
                                        direction=self.direction,
                                        cell_depth=self.cell_depth,
                                        use_peephole=self.use_peephole,
                                        keep_prob=self.keep_prob,
                                        cell_clip=self.cell_clip,
                                        num_proj=self.num_proj,
                                        time_major=self.time_major,
                                        forget_bias=self.forget_bias)
    expand_dims = P.ExpandDims()

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _, = dout
        # Only the last time step of the state gradients feeds the grad kernel.
        dh = dh[-1]
        dc = dc[-1]
        y, h, c, i, j, f, o, tanhct = out
        # init_h[0]/init_c[0]: drop the leading axis expected back below.
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        # Re-add the leading axis so the state grads match init_h/init_c.
        dh_prev = expand_dims(dh_prev, 0)
        dc_prev = expand_dims(dc_prev, 0)
        # (0) is a scalar placeholder for the non-differentiable seq_length input.
        return dx, dw, db, (0), dh_prev, dc_prev

    return bprop
@bprop_getters.register(P.DynamicGRUV2)
def get_bprop_dynamic_gru_v2(self):
    """Grad definition for `DynamicGRUV2` operation."""
    dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
                                             self.num_proj, self.time_major, self.gate_order,
                                             self.reset_after)

    def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
        # Forward outputs: y, hidden sequence, and the saved gate activations.
        y, out_h, update, reset, new, hidden_new = out
        dy, dout_h, _, _, _, _ = dout
        # Only the last step of the hidden-state gradient (dout_h[-1]) feeds back.
        dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
                                                                                    out_h, dy, dout_h[-1], update,
                                                                                    reset, new, hidden_new, None, None)
        # (0) is a scalar placeholder for the non-differentiable seq input.
        return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev

    return bprop
  851. @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
  852. def get_bprop_sigmoid_crossentropy_with_logits(self):
  853. """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
  854. op = G.SigmoidCrossEntropyWithLogitsGrad()
  855. def bprop(x, y, out, dout):
  856. dx = op(x, y, dout)
  857. return (dx, zeros_like(y))
  858. return bprop
  859. @bprop_getters.register(P.Pad)
  860. def get_bprop_pad(self):
  861. """Grad definition for `Pad` operation."""
  862. shape_op = P.Shape()
  863. paddings = self.paddings
  864. def bprop(x, out, dout):
  865. begin = ()
  866. for item in paddings:
  867. begin += (item[0],)
  868. shp = shape_op(x)
  869. dx = P.Slice()(dout, begin, shp)
  870. return (dx,)
  871. return bprop
  872. @bprop_getters.register(P.MirrorPad)
  873. def get_bprop_mirror_pad(self):
  874. """Grad definition for `MirrorPad` operation."""
  875. mirror_pad_grad = G.MirrorPadGrad(self.mode)
  876. def bprop(x, paddings, out, dout):
  877. dx = mirror_pad_grad(dout, paddings)
  878. return (dx, zeros_like(paddings))
  879. return bprop
  880. @bprop_getters.register(P.ROIAlign)
  881. def get_bprop_roi_align(self):
  882. """Grad definition for `ROIAlign` operation."""
  883. shape_op = P.Shape()
  884. pooled_height = self.pooled_height
  885. pooled_width = self.pooled_width
  886. spatial_scale = self.spatial_scale
  887. sample_num = self.sample_num
  888. def bprop(inputs, rois, out, dout):
  889. inputs_shape = shape_op(inputs)
  890. dx = G.ROIAlignGrad(inputs_shape,
  891. pooled_height,
  892. pooled_width,
  893. spatial_scale,
  894. sample_num,
  895. )(dout, rois)
  896. return dx, zeros_like(rois)
  897. return bprop
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Grad definition for `Conv2DBackpropInput` operation."""
    # Gradient w.r.t. the filter of the underlying convolution.
    filter_grad = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    # The gradient of an input-grad op w.r.t. its `dout` is a forward Conv2D.
    input_grad = P.Conv2D(
        self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )

    def bprop(x, w, f_sizes, out, dout):
        dx = input_grad(dout, w)
        # NOTE(review): env_force_bprop_seq is a module-level flag defined
        # elsewhere in this file, presumably read from an environment
        # variable — confirm. When '1', the depend edge serializes the
        # filter grad after the input grad.
        if env_force_bprop_seq == '1':
            x = F.depend(x, dx)
        dw = filter_grad(x, dout, F.shape(w))
        # f_sizes is a shape input with no gradient.
        return dx, dw, zeros_like(f_sizes)

    return bprop
  916. @bprop_getters.register(P.BinaryCrossEntropy)
  917. def get_bprop_binary_cross_entropy(self):
  918. """Grad definition for `BinaryCrossEntropy` operation."""
  919. grad = G.BinaryCrossEntropyGrad(self.reduction)
  920. def bprop(x, y, weight, out, dout):
  921. dx = grad(x, y, dout, weight)
  922. return dx, zeros_like(y), zeros_like(weight)
  923. return bprop
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_ce_with_logits_loss(self):
    """Grad definition for `BCEWithLogitsLoss` operation."""
    reduction = self.reduction
    mul = P.Mul()
    sigmoid = P.Sigmoid()
    add = P.TensorAdd()
    sub = P.Sub()
    size = P.Size()

    def bprop(predict, target, weight, pos_weight, out, dout):
        sigmoid_input = sigmoid(predict)
        if pos_weight is not None:
            # With positive-class weighting, let t = target * pos_weight:
            # dL/dpredict = ((t + 1 - target) * sigmoid(predict) - t) * dout
            t = mul(target, pos_weight)
            dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
        else:
            # Plain case: dL/dpredict = (sigmoid(predict) - target) * dout
            dx = mul((sigmoid_input - target), dout)
        if weight is not None:
            dx = mul(dx, weight)
        if reduction == 'mean':
            # Mean reduction spreads the loss gradient evenly over all elements.
            dx = dx / size(dx)
        return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight)

    return bprop
  946. @bprop_getters.register(P.KLDivLoss)
  947. def get_bprop_kl_div_loss(self):
  948. """Grad definition for `KLDivLoss` operation."""
  949. grad = G.KLDivLossGrad(self.reduction)
  950. def bprop(x, y, out, dout):
  951. dx, dy = grad(x, y, dout)
  952. return dx, dy
  953. return bprop
  954. @bprop_getters.register(P.Dropout)
  955. def get_bprop_dropout(self):
  956. """Grad definition for `Dropout` operation."""
  957. grad = G.DropoutGrad(self.keep_prob)
  958. def bprop(x, out, dout):
  959. _, mask = out
  960. dy, _ = dout
  961. dx = grad(dy, mask)
  962. return (dx,)
  963. return bprop
  964. @bprop_getters.register(P.CTCLoss)
  965. def get_bprop_ctc_loss(self):
  966. """Grad definition for `CTCLoss` operation"""
  967. expand = P.ExpandDims()
  968. def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
  969. grad_loss = out[1]
  970. grad = grad_loss * expand(dout[0], -1)
  971. return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
  972. return bprop
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        # First two forward outputs are unused here; it/jt/ft/ot/tanhct are
        # the saved gate activations.
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        # Backprop through the cell-state update to get the fused gate
        # gradient and the previous-cell-state gradient.
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        # Split the gate gradient into input and hidden contributions.
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        # depend edge orders the weight grad after the input grad.
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop
  990. @bprop_getters.register(P.LRN)
  991. def get_bprop_lrn(self):
  992. """Grad definition for `LRN` operation."""
  993. grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)
  994. def bprop(x, out, dout):
  995. dx = grad(dout, x, out)
  996. return (dx,)
  997. return bprop