You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

grad_nn_ops.py 19 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of neural network related operations."""
  16. from mindspore.common import dtype as mstype
  17. from .. import functional as F
  18. from .. import operations as P
  19. from ..operations import _grad_ops as G
  20. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  21. from .grad_base import bprop_getters
  22. @bprop_getters.register(P.BiasAdd)
  23. def get_bprop_bias_add(self):
  24. """Grad definition for `BiasAdd` operation."""
  25. bias_grad = G.BiasAddGrad()
  26. def bprop(x, w, out, dout):
  27. return dout, bias_grad(dout)
  28. return bprop
  29. @bprop_getters.register(P.Conv2D)
  30. def get_bprop_conv2d(self):
  31. """Grad definition for `Conv2D` operation."""
  32. input_grad = P.Conv2DBackpropInput(
  33. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  34. dilation=self.dilation, stride=self.stride, group=self.group
  35. )
  36. filter_grad = G.Conv2DBackpropFilter(
  37. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  38. dilation=self.dilation, stride=self.stride, group=self.group
  39. )
  40. get_shape = P.Shape()
  41. def bprop(x, w, out, dout):
  42. dx = input_grad(dout, w, get_shape(x))
  43. dw = filter_grad(dout, x, get_shape(w))
  44. return dx, dw
  45. return bprop
  46. @bprop_getters.register(P.ExtractImagePatches)
  47. def get_bprop_extract_image_patches(self):
  48. """Grad definition for `ExtractImagePatches` operation."""
  49. get_shape = P.Shape()
  50. reshape = P.Reshape()
  51. extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes,
  52. strides=self.strides,
  53. rates=self.rates,
  54. padding=self.padding)
  55. concat = P.Concat(axis=-1)
  56. expand_dims = P.ExpandDims()
  57. scatter_nd = P.ScatterNd()
  58. dtype = P.DType()
  59. fill = P.Fill()
  60. slice_op = P.Slice()
  61. transpose = P.Transpose()
  62. matmul = P.MatMul()
  63. cast = P.Cast()
  64. _, ksizes_row, ksizes_col, _ = self.ksizes
  65. def bprop(x, out, dout):
  66. x_shape = get_shape(x)
  67. x_batch, x_row, x_col, x_depth = x_shape
  68. x_indices_num = x_row * x_col + 1
  69. x_idx = F.tuple_to_array(range(1, x_indices_num))
  70. x_idx = reshape(x_idx, (1, x_row, x_col, 1))
  71. x_idx = cast(x_idx, mstype.float16)
  72. x_idx_patch = extract_image_patches(x_idx)
  73. x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2))
  74. x_idx_patch = cast(x_idx_patch, mstype.int32)
  75. out_shape = get_shape(out)
  76. _, out_row, out_col, _ = out_shape
  77. out_indices_num = out_row * out_col * ksizes_row * ksizes_col
  78. out_idx = F.tuple_to_array(range(out_indices_num))
  79. out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col))
  80. idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
  81. idx_tensor = reshape(idx_tensor, (-1, 2))
  82. sp_shape = (x_indices_num, out_indices_num)
  83. sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
  84. sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
  85. grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
  86. grad = transpose(grad, (1, 2, 3, 4, 0, 5))
  87. grad = reshape(grad, (-1, x_batch * x_depth))
  88. jac = matmul(sp_tensor, grad)
  89. dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
  90. dx = transpose(dx, (2, 0, 1, 3))
  91. return (dx,)
  92. return bprop
  93. @bprop_getters.register(P.DepthwiseConv2dNative)
  94. def get_bprop_depthwise_conv2d_native(self):
  95. """Grad definition for `DepthwiseConv2dNative` operation."""
  96. input_grad = G.DepthwiseConv2dNativeBackpropInput(
  97. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
  98. self.dilation, self.group
  99. )
  100. filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
  101. self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
  102. self.dilation, self.group
  103. )
  104. get_shape = P.Shape()
  105. def bprop(x, w, out, dout):
  106. dx = input_grad(get_shape(x), w, dout)
  107. dw = filter_grad(x, get_shape(w), dout)
  108. return dx, dw
  109. return bprop
  110. @bprop_getters.register(P.MaxPoolWithArgmax)
  111. def get_bprop_max_pool_with_argmax(self):
  112. """Grad definition for `MaxPoolWithArgmax` operation."""
  113. maxpool_grad = G.MaxPoolGradWithArgmax(
  114. ksize=self.ksize,
  115. strides=self.strides,
  116. padding=self.padding,)
  117. def bprop(x, out, dout):
  118. dx = maxpool_grad(x, dout[0], out[1])
  119. return (dx,)
  120. return bprop
  121. @bprop_getters.register(P.MaxPool)
  122. def get_bprop_max_pool_grad(self):
  123. """Grad definition for `MaxPool` operation."""
  124. maxpool_grad = G.MaxPoolGrad(
  125. ksize=self.ksize,
  126. strides=self.strides,
  127. padding=self.padding)
  128. def bprop(x, out, dout):
  129. dx = maxpool_grad(x, out, dout)
  130. return (dx,)
  131. return bprop
  132. @bprop_getters.register(P.AvgPool)
  133. def get_bprop_avg_pool_grad(self):
  134. """Grad definition for `AvgPool` operation."""
  135. avgpool_grad = G.AvgPoolGrad(
  136. ksize=self.ksize,
  137. strides=self.strides,
  138. padding=self.padding)
  139. shape_op = P.Shape()
  140. avgpool_grad_gpu = G.AvgPoolGradGpu(
  141. ksize=self.ksize,
  142. strides=self.strides,
  143. padding=self.padding)
  144. def bprop(x, out, dout):
  145. dx = avgpool_grad(shape_op(x), dout)
  146. return (dx,)
  147. def bprop_gpu(x, out, dout):
  148. dx = avgpool_grad_gpu(x, out, dout)
  149. return (dx,)
  150. # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
  151. if self.target == "GPU":
  152. bprop_fn = bprop_gpu
  153. else:
  154. bprop_fn = bprop
  155. return bprop_fn
  156. @bprop_getters.register(P.DropoutGenMask)
  157. def get_bprop_dropout_gen_mask(self):
  158. """Grad definition for `DropoutGenMask` operation."""
  159. def bprop(shape, keep_prob, out, dout):
  160. return (zeros_like(shape), zeros_like(keep_prob))
  161. return bprop
  162. @bprop_getters.register(P.DropoutDoMask)
  163. def get_bprop_dropout_do_mask(self):
  164. """Grad definition for `DropoutDoMask` operation."""
  165. do_mask = P.DropoutDoMask()
  166. def bprop(x, y, keep_prob, out, dout):
  167. return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
  168. return bprop
  169. @bprop_getters.register(P.ReLU)
  170. def get_bprop_relu(self):
  171. """Grad definition for `ReLU` operation."""
  172. input_grad = G.ReluGrad()
  173. def bprop(x, out, dout):
  174. dx = input_grad(dout, out)
  175. return (dx,)
  176. return bprop
  177. @bprop_getters.register(P.ReLU6)
  178. def get_bprop_relu6(self):
  179. """Grad definition for `ReLU6` operation."""
  180. input_grad = G.ReLU6Grad()
  181. def bprop(x, out, dout):
  182. dx = input_grad(dout, x)
  183. return (dx,)
  184. return bprop
  185. @bprop_getters.register(P.ReLUV2)
  186. def get_bprop_relu_v2(self):
  187. """Grad definition for `ReLUV2` operation."""
  188. input_grad = G.ReluGradV2()
  189. def bprop(x, out, dout):
  190. mask = out[1]
  191. dx = input_grad(dout[0], mask)
  192. return (dx,)
  193. return bprop
  194. @bprop_getters.register(P.HSwish)
  195. def get_bprop_hswish(self):
  196. """Grad definition for `HSwish` operation."""
  197. input_grad = G.HSwishGrad()
  198. def bprop(x, out, dout):
  199. dx = input_grad(dout, x)
  200. return (dx,)
  201. return bprop
  202. @bprop_getters.register(P.HSigmoid)
  203. def get_bprop_hsigmoid(self):
  204. """Grad definition for `HSigmoid` operation."""
  205. input_grad = G.HSigmoidGrad()
  206. def bprop(x, out, dout):
  207. dx = input_grad(dout, x)
  208. return (dx,)
  209. return bprop
  210. @bprop_getters.register(P.Elu)
  211. def get_bprop_elu(self):
  212. """Grad definition for `Elu` operation."""
  213. input_grad = G.EluGrad()
  214. def bprop(x, out, dout):
  215. dx = input_grad(dout, x)
  216. return (dx,)
  217. return bprop
  218. @bprop_getters.register(P.Sigmoid)
  219. def get_bprop_sigmoid(self):
  220. """Grad definition for `Sigmoid` operation."""
  221. input_grad = G.SigmoidGrad()
  222. def bprop(x, out, dout):
  223. dx = input_grad(out, dout)
  224. return (dx,)
  225. return bprop
  226. @bprop_getters.register(P.Softmax)
  227. def get_bprop_softmax(self):
  228. """Grad definition for `Softmax` operation."""
  229. sum_func = P.ReduceSum(keep_dims=True)
  230. sub = P.Sub()
  231. mul = P.Mul()
  232. axis = self.axis
  233. def bprop(x, out, dout):
  234. dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out)
  235. return (dx,)
  236. return bprop
  237. @bprop_getters.register(P.LogSoftmax)
  238. def get_bprop_log_softmax(self):
  239. """Grad definition for `LogSoftmax` operation."""
  240. logsoftmax_grad = G.LogSoftmaxGrad(self.axis)
  241. def bprop(x, out, dout):
  242. dx = logsoftmax_grad(out, dout)
  243. return (dx,)
  244. return bprop
  245. @bprop_getters.register(P.Tanh)
  246. def get_bprop_tanh(self):
  247. """Grad definition for `Tanh` operation."""
  248. logsoftmax_grad = G.TanhGrad()
  249. def bprop(x, out, dout):
  250. dx = logsoftmax_grad(out, dout)
  251. return (dx,)
  252. return bprop
  253. @bprop_getters.register(P.Gelu)
  254. def get_bprop_gelu(self):
  255. """Grad definition for `Gelu` operation."""
  256. input_grad = G.GeluGrad()
  257. def bprop(x, out, dout):
  258. dx = input_grad(dout, x, out)
  259. return (dx,)
  260. return bprop
  261. @bprop_getters.register(P.FusedBatchNorm)
  262. def get_bprop_fused_batch_norm(self):
  263. """Grad definition for `FusedBatchNorm` operation."""
  264. input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum)
  265. def bprop(x, scale, b, mean, variance, out, dout):
  266. saved_mean = out[3]
  267. saved_variance = out[4]
  268. out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
  269. dx = out[0]
  270. dscale = out[1]
  271. dbias = out[2]
  272. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  273. return bprop
  274. @bprop_getters.register(P.BatchNorm)
  275. def get_bprop_batch_norm(self):
  276. """Grad definition for `BatchNorm` operation."""
  277. is_training = self.is_training
  278. input_grad = G.BatchNormGrad(is_training, self.epsilon)
  279. def bprop(x, scale, b, mean, variance, out, dout):
  280. if is_training:
  281. saved_reserve_1 = out[3]
  282. saved_reserve_2 = out[4]
  283. saved_reserve_3 = out[5]
  284. else:
  285. saved_reserve_1 = mean
  286. saved_reserve_2 = variance
  287. saved_reserve_3 = variance
  288. out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3)
  289. dx = out[0]
  290. dscale = out[1]
  291. dbias = out[2]
  292. return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
  293. return bprop
  294. @bprop_getters.register(P.LayerNorm)
  295. def get_bprop_layer_norm(self):
  296. """Grad definition for `LayerNorm` operation."""
  297. layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)
  298. def bprop(x, gamma, beta, out, dout):
  299. dx, d_gamma, d_beta = layer_norm_grad(x, dout[0], out[2], out[1], gamma)
  300. return dx, d_gamma, d_beta
  301. return bprop
  302. @bprop_getters.register(P.L2Normalize)
  303. def get_bprop_l2normalize(self):
  304. """Grad definition for `L2Normalize` operation."""
  305. input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)
  306. def bprop(x, out, dout):
  307. dx = input_grad(x, out, dout)
  308. return (dx,)
  309. return bprop
  310. @bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
  311. def get_bprop_softmax_cross_entropy_with_logits(self):
  312. """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
  313. expand = P.ExpandDims()
  314. def bprop(logits, labels, out, dout):
  315. grad = out[1]
  316. grad = grad * expand(dout[0], -1)
  317. return grad, zeros_like(labels)
  318. return bprop
  319. @bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
  320. def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
  321. """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
  322. is_grad = self.is_grad
  323. grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)
  324. def bprop(logits, labels, out, dout):
  325. grad = out[0]
  326. if not is_grad:
  327. # if construct use loss
  328. grad = grad_op(logits, labels)
  329. grad = F.depend(grad, out)
  330. grad = grad * dout
  331. return grad, zeros_like(labels)
  332. return bprop
  333. @bprop_getters.register(P.ResizeBilinear)
  334. def get_bprop_resize_bilinear(self):
  335. """Grad definition for `ResizeBilinear` operation."""
  336. resize_grad = G.ResizeBilinearGrad(self.align_corners)
  337. def bprop(x, out, dout):
  338. dx = resize_grad(dout, x)
  339. return (dx,)
  340. return bprop
  341. @bprop_getters.register(P.OneHot)
  342. def get_bprop_onehot(self):
  343. """Grad definition for `OneHot` operation."""
  344. def bprop(indices, depth, on_value, off_value, out, dout):
  345. return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
  346. return bprop
  347. @bprop_getters.register(P.TopK)
  348. def get_bprop_top_kv2(self):
  349. """Grad definition for `TopK` operation."""
  350. scatter = P.ScatterNd()
  351. expand_dims = P.ExpandDims()
  352. shape_op = P.Shape()
  353. def bprop(input_x, k, out, dout):
  354. indices = out[1]
  355. indices = expand_dims(indices, -1)
  356. updates = dout[0]
  357. shapes = shape_op(input_x)
  358. return scatter(indices, updates, shapes), zeros_like(k)
  359. return bprop
  360. @bprop_getters.register(P.SmoothL1Loss)
  361. def get_bprop_smooth_l1_loss(self):
  362. """Grad definition for `SmoothL1Loss` operation."""
  363. grad = G.SmoothL1LossGrad(self.sigma)
  364. def bprop(prediction, target, out, dout):
  365. dx = grad(prediction, target, dout)
  366. return dx, zeros_like(target)
  367. return bprop
  368. @bprop_getters.register(P.L2Loss)
  369. def get_bprop_l2_loss(self):
  370. """Grad definition for `L2Loss` operation."""
  371. def bprop(x, out, dout):
  372. dx = x * dout
  373. return (dx,)
  374. return bprop
  375. @bprop_getters.register(P.PReLU)
  376. def get_bprop_prelu(self):
  377. """Grad definition for `PReLU` operation."""
  378. grad = G.PReLUGrad()
  379. def bprop(x, w, out, dout):
  380. dx, dw = grad(dout, x, w)
  381. return dx, dw
  382. return bprop
  383. @bprop_getters.register(P.LSTM)
  384. def get_bprop_lstm(self):
  385. """Grad definition for `LSTM` operation."""
  386. lstm_grad_data = G.LSTMGradData(
  387. input_size=self.input_size,
  388. hidden_size=self.hidden_size,
  389. num_layers=self.num_layers,
  390. has_bias=self.has_bias,
  391. bidirectional=self.bidirectional,
  392. dropout=self.dropout
  393. )
  394. lstm_grad_weight = G.LSTMGradWeight(
  395. input_size=self.input_size,
  396. hidden_size=self.hidden_size,
  397. num_layers=self.num_layers,
  398. has_bias=self.has_bias,
  399. bidirectional=self.bidirectional,
  400. dropout=self.dropout
  401. )
  402. def bprop(x, hx, cx, w, out, dout):
  403. y, _, _, reserve, state = out
  404. dy, dhy, dcy, _, _ = dout
  405. dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
  406. dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
  407. return dx, dhx, dcx, dw
  408. return bprop
  409. @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
  410. def get_bprop_sigmoid_crossentropy_with_logits(self):
  411. """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
  412. op = G.SigmoidCrossEntropyWithLogitsGrad()
  413. def bprop(x, y, out, dout):
  414. dx = op(x, y, dout)
  415. return (dx, zeros_like(y))
  416. return bprop
  417. @bprop_getters.register(P.Pad)
  418. def get_bprop_pad(self):
  419. """Grad definition for `Pad` operation."""
  420. shape_op = P.Shape()
  421. paddings = self.paddings
  422. def bprop(x, out, dout):
  423. begin = ()
  424. for item in paddings:
  425. begin += (item[0],)
  426. shp = shape_op(x)
  427. dx = P.Slice()(dout, begin, shp)
  428. return (dx,)
  429. return bprop
  430. @bprop_getters.register(P.MirrorPad)
  431. def get_bprop_mirror_pad(self):
  432. """Grad definition for `MirrorPad` operation."""
  433. mirror_pad_grad = G.MirrorPadGrad(self.mode)
  434. def bprop(x, paddings, out, dout):
  435. dx = mirror_pad_grad(dout, paddings, x)
  436. return (dx, zeros_like(paddings))
  437. return bprop
  438. @bprop_getters.register(P.ROIAlign)
  439. def get_bprop_roi_align(self):
  440. """Grad definition for `ROIAlign` operation."""
  441. shape_op = P.Shape()
  442. pooled_height = self.pooled_height
  443. pooled_width = self.pooled_width
  444. spatial_scale = self.spatial_scale
  445. sample_num = self.sample_num
  446. def bprop(inputs, rois, out, dout):
  447. rois_shape = shape_op(rois)
  448. inputs_shape = shape_op(inputs)
  449. dx = G.ROIAlignGrad(inputs_shape,
  450. pooled_height,
  451. pooled_width,
  452. spatial_scale,
  453. sample_num,
  454. )(dout, rois)
  455. return dx, zeros_like(rois_shape)
  456. return bprop
  457. @bprop_getters.register(P.Conv2DBackpropInput)
  458. def get_bprop_conv2d_backprop_input(self):
  459. """Grad definition for `Conv2DBackpropInput` operation."""
  460. filter_grad = G.Conv2DBackpropFilter(
  461. self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
  462. dilation=self.dilation, stride=self.stride, group=self.group
  463. )
  464. input_grad = P.Conv2D(
  465. self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
  466. dilation=self.dilation, stride=self.stride, group=self.group
  467. )
  468. def bprop(x, w, f_sizes, out, dout):
  469. dx = input_grad(dout, w)
  470. dw = filter_grad(x, dout, F.shape(w))
  471. return dx, dw
  472. return bprop
  473. @bprop_getters.register(P.BinaryCrossEntropy)
  474. def get_bprop_binary_cross_entropy(self):
  475. """Grad definition for `BinaryCrossEntropy` operation."""
  476. grad = G.BinaryCrossEntropyGrad(self.reduction)
  477. def bprop(x, y, weight, out, dout):
  478. dx = grad(x, y, dout, weight)
  479. return dx, zeros_like(y), zeros_like(weight)
  480. return bprop