|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Define the grad rules of neural network related operations."""
- from mindspore.common import dtype as mstype
- from .. import functional as F
- from .. import operations as P
- from ..operations import _grad_ops as G
- from ..operations import _inner_ops as inner
- from ..composite.multitype_ops.zeros_like_impl import zeros_like
- from .grad_base import bprop_getters
-
-
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
    """Build the backward function registered for the `BiasAdd` primitive.

    Returns:
        Function, `bprop(x, w, out, dout)` returning a 2-tuple: `dout` itself
        for the first input and `BiasAddGrad()(dout)` for the second input.
    """
    grad_of_bias = G.BiasAddGrad()

    def bprop(x, w, out, dout):
        # The first input receives the incoming gradient unchanged.
        d_bias = grad_of_bias(dout)
        return dout, d_bias

    return bprop
-
-
@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
    """Build the backward function registered for the `Conv2D` primitive.

    Both gradient operators are configured from the attributes of the forward
    primitive (`self`).

    Returns:
        Function, `bprop(x, w, out, dout)` returning `(dx, dw)`.
    """
    backprop_to_input = P.Conv2DBackpropInput(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group
    )
    backprop_to_filter = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group
    )
    shape_of = P.Shape()

    def bprop(x, w, out, dout):
        # Each gradient operator is also given the shape of the tensor whose gradient it produces.
        x_shape = shape_of(x)
        w_shape = shape_of(w)
        return backprop_to_input(dout, w, x_shape), backprop_to_filter(dout, x, w_shape)

    return bprop
-
-
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Build the backward function registered for the `ExtractImagePatches` primitive.

    The backward pass runs the forward operator on a tensor of position labels
    to find out which input position every output element was read from,
    scatters that relation into a dense (input position x output position)
    matrix of ones, and multiplies that matrix with the rearranged `dout`.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`.
    """
    shape_of = P.Shape()
    reshape_op = P.Reshape()
    patch_op = inner.ExtractImagePatches(ksizes=self.ksizes,
                                         strides=self.strides,
                                         rates=self.rates,
                                         padding=self.padding)
    concat_last = P.Concat(axis=-1)
    unsqueeze = P.ExpandDims()
    scatter_op = P.ScatterNd()
    dtype_of = P.DType()
    fill_op = P.Fill()
    slice_op = P.Slice()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul()
    cast_op = P.Cast()
    _, k_rows, k_cols, _ = self.ksizes

    def bprop(x, out, dout):
        in_shape = shape_of(x)
        batch, in_rows, in_cols, depth = in_shape
        patch_shape = shape_of(out)
        _, patch_rows, patch_cols, _ = patch_shape

        # Label the input positions 1 .. in_rows * in_cols and push the labels through the forward op.
        # NOTE(review): float16 holds integers exactly only up to 2048, so labels could be rounded for
        # inputs with more positions than that - confirm whether a wider dtype is accepted here.
        n_src = in_rows * in_cols + 1
        src_labels = F.tuple_to_array(range(1, n_src))
        src_labels = cast_op(reshape_op(src_labels, (1, in_rows, in_cols, 1)), mstype.float16)
        src_of_patch = transpose_op(patch_op(src_labels), (0, 3, 1, 2))
        src_of_patch = cast_op(src_of_patch, mstype.int32)

        # Enumerate every element of one output image: 0 .. n_dst - 1.
        n_dst = patch_rows * patch_cols * k_rows * k_cols
        dst_labels = F.tuple_to_array(range(n_dst))
        dst_labels = reshape_op(dst_labels, (1, k_rows * k_cols, patch_rows, patch_cols))

        # Pair (source label, destination label) and scatter ones at those coordinates.
        pairs = concat_last((unsqueeze(src_of_patch, -1), unsqueeze(dst_labels, -1)))
        pairs = reshape_op(pairs, (-1, 2))
        ones = fill_op(dtype_of(dout), (n_dst,), 1)
        relation = scatter_op(pairs, ones, (n_src, n_dst))
        # Row 0 is discarded: real labels start at 1 (presumably 0 marks padded positions - TODO confirm).
        relation = slice_op(relation, (1, 0), (n_src - 1, n_dst))

        # Bring dout into (output position, batch * depth) layout for the matrix product.
        dout_mat = reshape_op(dout, (batch, patch_rows, patch_cols, k_rows, k_cols, depth))
        dout_mat = transpose_op(dout_mat, (1, 2, 3, 4, 0, 5))
        dout_mat = reshape_op(dout_mat, (-1, batch * depth))

        product = matmul_op(relation, dout_mat)
        dx = reshape_op(product, (in_rows, in_cols, batch, depth))
        dx = transpose_op(dx, (2, 0, 1, 3))

        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
    """Build the backward function registered for the `DepthwiseConv2dNative` primitive.

    Both gradient operators are configured from the attributes of the forward
    primitive (`self`).

    Returns:
        Function, `bprop(x, w, out, dout)` returning `(dx, dw)`.
    """
    backprop_to_input = G.DepthwiseConv2dNativeBackpropInput(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
        self.dilation, self.group
    )
    backprop_to_filter = G.DepthwiseConv2dNativeBackpropFilter(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pads, self.mode, self.stride,
        self.dilation, self.group
    )
    shape_of = P.Shape()

    def bprop(x, w, out, dout):
        # The input gradient takes the shape of x; the filter gradient takes the shape of w.
        x_shape = shape_of(x)
        w_shape = shape_of(w)
        return backprop_to_input(x_shape, w, dout), backprop_to_filter(x, w_shape, dout)

    return bprop
-
-
@bprop_getters.register(P.MaxPoolWithArgmax)
def get_bprop_max_pool_with_argmax(self):
    """Build the backward function registered for the `MaxPoolWithArgmax` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `MaxPoolGradWithArgmax` from `x`, `dout[0]` and `out[1]`.
    """
    pool_grad_op = G.MaxPoolGradWithArgmax(
        ksize=self.ksize,
        strides=self.strides,
        padding=self.padding)

    def bprop(x, out, dout):
        # Only the first entry of dout and the second entry of out are used.
        d_pooled = dout[0]
        argmax = out[1]
        return (pool_grad_op(x, d_pooled, argmax),)

    return bprop
-
-
@bprop_getters.register(P.MaxPool)
def get_bprop_max_pool_grad(self):
    """Build the backward function registered for the `MaxPool` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `MaxPoolGrad` configured with the forward primitive's `ksize`,
        `strides` and `padding`.
    """
    pool_grad_op = G.MaxPoolGrad(
        ksize=self.ksize,
        strides=self.strides,
        padding=self.padding)

    def bprop(x, out, dout):
        return (pool_grad_op(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Build the backward function registered for the `AvgPool` primitive.

    Two variants are prepared because the GPU gradient operator takes
    different arguments from the default one; the variant is selected by
    `self.target`.

    Returns:
        Function, `bprop(x, out, dout)` (or its GPU counterpart with the same
        signature) returning `(dx,)`.
    """
    default_grad_op = G.AvgPoolGrad(
        ksize=self.ksize,
        strides=self.strides,
        padding=self.padding)
    shape_of = P.Shape()

    gpu_grad_op = G.AvgPoolGradGpu(
        ksize=self.ksize,
        strides=self.strides,
        padding=self.padding)

    def bprop(x, out, dout):
        # Default variant: only the shape of x is passed along with dout.
        return (default_grad_op(shape_of(x), dout),)

    def bprop_gpu(x, out, dout):
        # GPU variant: the operator takes x and out themselves.
        return (gpu_grad_op(x, out, dout),)

    # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
    return bprop_gpu if self.target == "GPU" else bprop
-
-
@bprop_getters.register(P.DropoutGenMask)
def get_bprop_dropout_gen_mask(self):
    """Build the backward function registered for the `DropoutGenMask` primitive.

    Returns:
        Function, `bprop(shape, keep_prob, out, dout)` returning
        `zeros_like` of each of the two inputs.
    """

    def bprop(shape, keep_prob, out, dout):
        d_shape = zeros_like(shape)
        d_keep_prob = zeros_like(keep_prob)
        return (d_shape, d_keep_prob)

    return bprop
-
-
@bprop_getters.register(P.DropoutDoMask)
def get_bprop_dropout_do_mask(self):
    """Build the backward function registered for the `DropoutDoMask` primitive.

    Returns:
        Function, `bprop(x, y, keep_prob, out, dout)` returning a 3-tuple:
        `DropoutDoMask` applied to `dout` with the same `y` and `keep_prob`,
        followed by `zeros_like(y)` and `zeros_like(keep_prob)`.
    """
    apply_mask = P.DropoutDoMask()

    def bprop(x, y, keep_prob, out, dout):
        # The same operator is re-applied to dout, reusing the forward pass's y and keep_prob.
        dx = apply_mask(dout, y, keep_prob)
        return (dx, zeros_like(y), zeros_like(keep_prob))

    return bprop
-
-
@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Build the backward function registered for the `ReLU` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(ReluGrad()(dout, out),)`.
    """
    relu_grad_op = G.ReluGrad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward output, not the forward input.
        return (relu_grad_op(dout, out),)

    return bprop
-
-
@bprop_getters.register(P.ReLU6)
def get_bprop_relu6(self):
    """Build the backward function registered for the `ReLU6` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(ReLU6Grad()(dout, x),)`.
    """
    relu6_grad_op = G.ReLU6Grad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward input.
        return (relu6_grad_op(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.ReLUV2)
def get_bprop_relu_v2(self):
    """Build the backward function registered for the `ReLUV2` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `ReluGradV2` from `dout[0]` and the second forward output `out[1]`.
    """
    relu_v2_grad_op = G.ReluGradV2()

    def bprop(x, out, dout):
        # out[1] is referred to as the mask; only the first entry of dout is used.
        d_activated = dout[0]
        mask = out[1]
        return (relu_v2_grad_op(d_activated, mask),)

    return bprop
-
-
@bprop_getters.register(P.HSwish)
def get_bprop_hswish(self):
    """Build the backward function registered for the `HSwish` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(HSwishGrad()(dout, x),)`.
    """
    hswish_grad_op = G.HSwishGrad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward input.
        return (hswish_grad_op(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.HSigmoid)
def get_bprop_hsigmoid(self):
    """Build the backward function registered for the `HSigmoid` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(HSigmoidGrad()(dout, x),)`.
    """
    hsigmoid_grad_op = G.HSigmoidGrad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward input.
        return (hsigmoid_grad_op(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.Elu)
def get_bprop_elu(self):
    """Build the backward function registered for the `Elu` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(EluGrad()(dout, out),)`.
    """
    elu_grad_op = G.EluGrad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward output, not the forward input.
        return (elu_grad_op(dout, out),)

    return bprop
-
-
@bprop_getters.register(P.Sigmoid)
def get_bprop_sigmoid(self):
    """Build the backward function registered for the `Sigmoid` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(SigmoidGrad()(out, dout),)`.
    """
    sigmoid_grad_op = G.SigmoidGrad()

    def bprop(x, out, dout):
        # Argument order for this operator is (forward output, incoming gradient).
        return (sigmoid_grad_op(out, dout),)

    return bprop
-
-
@bprop_getters.register(P.Softmax)
def get_bprop_softmax(self):
    """Build the backward function registered for the `Softmax` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)` with
        `dx = (dout - sum(dout * out, axis)) * out`, where the sum keeps the
        reduced dimensions and `axis` is the forward primitive's `axis`.
    """
    reduce_keep = P.ReduceSum(keep_dims=True)
    subtract = P.Sub()
    multiply = P.Mul()
    reduce_axis = self.axis

    def bprop(x, out, dout):
        weighted = multiply(dout, out)
        total = reduce_keep(weighted, reduce_axis)
        shifted = subtract(dout, total)
        return (multiply(shifted, out),)

    return bprop
-
-
@bprop_getters.register(P.LogSoftmax)
def get_bprop_log_softmax(self):
    """Build the backward function registered for the `LogSoftmax` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `LogSoftmaxGrad` (created with the forward primitive's `axis`) from
        `out` and `dout`.
    """
    log_softmax_grad_op = G.LogSoftmaxGrad(self.axis)

    def bprop(x, out, dout):
        # Argument order for this operator is (forward output, incoming gradient).
        return (log_softmax_grad_op(out, dout),)

    return bprop
-
-
@bprop_getters.register(P.Softplus)
def get_bprop_softplus(self):
    """Build the backward function registered for the `Softplus` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(SoftplusGrad()(dout, x),)`.
    """
    grad_op = G.SoftplusGrad()

    def bprop(x, out, dout):
        # The gradient operator is fed the forward input.
        return (grad_op(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
    """Grad definition for `Tanh` operation.

    Returns:
        Function, `bprop(x, out, dout)` returning `(TanhGrad()(out, dout),)`.
    """
    # Named after the operator it actually holds; the earlier name referred to LogSoftmax.
    tanh_grad = G.TanhGrad()

    def bprop(x, out, dout):
        # Argument order for this operator is (forward output, incoming gradient).
        dx = tanh_grad(out, dout)
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.Gelu)
def get_bprop_gelu(self):
    """Build the backward function registered for the `Gelu` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(GeluGrad()(dout, x, out),)`.
    """
    gelu_grad_op = G.GeluGrad()

    def bprop(x, out, dout):
        # This operator takes both the forward input and the forward output.
        return (gelu_grad_op(dout, x, out),)

    return bprop
-
-
@bprop_getters.register(P.FusedBatchNorm)
def get_bprop_fused_batch_norm(self):
    """Build the backward function registered for the `FusedBatchNorm` primitive.

    Returns:
        Function, `bprop(x, scale, b, mean, variance, out, dout)` returning a
        5-tuple: the first three results of `FusedBatchNormGrad` followed by
        `zeros_like(mean)` and `zeros_like(variance)`.
    """
    bn_grad_op = G.FusedBatchNormGrad(self.epsilon, self.momentum)

    def bprop(x, scale, b, mean, variance, out, dout):
        # The statistics come from the forward outputs at positions 3 and 4; only dout[0] is used.
        grads = bn_grad_op(dout[0], x, scale, out[3], out[4])
        return grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance)

    return bprop
-
-
@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
    """Build the backward function registered for the `BatchNorm` primitive.

    Returns:
        Function, `bprop(x, scale, b, mean, variance, out, dout)` returning a
        5-tuple: the first three results of `BatchNormGrad` followed by
        `zeros_like(mean)` and `zeros_like(variance)`.
    """
    training = self.is_training
    bn_grad_op = G.BatchNormGrad(training, self.epsilon)

    def bprop(x, scale, b, mean, variance, out, dout):
        # In training mode the statistics are read from forward outputs 3 and 4;
        # otherwise the mean and variance inputs are passed instead.
        if training:
            reserve_a = out[3]
            reserve_b = out[4]
        else:
            reserve_a = mean
            reserve_b = variance
        grads = bn_grad_op(dout[0], x, scale, reserve_a, reserve_b)
        return grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance)

    return bprop
-
-
@bprop_getters.register(P.LayerNorm)
def get_bprop_layer_norm(self):
    """Build the backward function registered for the `LayerNorm` primitive.

    Returns:
        Function, `bprop(x, gamma, beta, out, dout)` returning
        `(dx, d_gamma, d_beta)` as produced by `LayerNormGrad`.
    """
    ln_grad_op = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, gamma, beta, out, dout):
        # Forward outputs are passed in the order out[2], out[1]; only dout[0] is used.
        d_normed = dout[0]
        dx, d_gamma, d_beta = ln_grad_op(x, d_normed, out[2], out[1], gamma)
        return dx, d_gamma, d_beta

    return bprop
-
-
@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
    """Build the backward function registered for the `L2Normalize` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `L2NormalizeGrad` created with the forward primitive's `axis` and
        `epsilon`.
    """
    l2norm_grad_op = G.L2NormalizeGrad(self.axis, self.epsilon)

    def bprop(x, out, dout):
        return (l2norm_grad_op(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
def get_bprop_softmax_cross_entropy_with_logits(self):
    """Build the backward function registered for the `SoftmaxCrossEntropyWithLogits` primitive.

    Returns:
        Function, `bprop(logits, labels, out, dout)` returning a 2-tuple:
        `out[1]` multiplied by `dout[0]` (expanded with a trailing axis) for
        `logits`, and `zeros_like(labels)` for `labels`.
    """
    add_axis = P.ExpandDims()

    def bprop(logits, labels, out, dout):
        # The second forward output is scaled by the incoming gradient of the first output.
        scale = add_axis(dout[0], -1)
        d_logits = out[1] * scale
        return d_logits, zeros_like(labels)

    return bprop
-
-
@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Build the backward function registered for the `SparseSoftmaxCrossEntropyWithLogits` primitive.

    When the forward primitive has `is_grad` unset, the gradient is recomputed
    by a second instance of the primitive created with `is_grad=True`, tied to
    the forward output through `F.depend`, and scaled by `dout`. Otherwise
    `out[0]` is returned as it is.

    Returns:
        Function, `bprop(logits, labels, out, dout)` returning
        `(d_logits, zeros_like(labels))`.
    """
    forward_is_grad = self.is_grad
    recompute_grad = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)

    def bprop(logits, labels, out, dout):
        d_logits = out[0]
        if not forward_is_grad:
            # if construct use loss
            d_logits = recompute_grad(logits, labels)
            d_logits = F.depend(d_logits, out)
            d_logits = d_logits * dout
        return d_logits, zeros_like(labels)

    return bprop
-
-
@bprop_getters.register(P.ResizeBilinear)
def get_bprop_resize_bilinear(self):
    """Build the backward function registered for the `ResizeBilinear` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `ResizeBilinearGrad` (created with the forward primitive's
        `align_corners`) from `dout` and `x`.
    """
    bilinear_grad_op = G.ResizeBilinearGrad(self.align_corners)

    def bprop(x, out, dout):
        return (bilinear_grad_op(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.OneHot)
def get_bprop_onehot(self):
    """Build the backward function registered for the `OneHot` primitive.

    Returns:
        Function, `bprop(indices, depth, on_value, off_value, out, dout)`
        returning `zeros_like` of each of the four inputs.
    """

    def bprop(indices, depth, on_value, off_value, out, dout):
        d_indices = zeros_like(indices)
        d_depth = zeros_like(depth)
        d_on = zeros_like(on_value)
        d_off = zeros_like(off_value)
        return d_indices, d_depth, d_on, d_off

    return bprop
-
-
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Build the backward function registered for the `TopK` primitive.

    Returns:
        Function, `bprop(input_x, k, out, dout)` returning a 2-tuple: `dout[0]`
        scattered (via `ScatterNd`) into the shape of `input_x` at the indices
        `out[1]`, and `zeros_like(k)`.
    """
    scatter_op = P.ScatterNd()
    add_axis = P.ExpandDims()
    shape_of = P.Shape()

    def bprop(input_x, k, out, dout):
        # The forward indices get a trailing axis before being used as scatter coordinates.
        coords = add_axis(out[1], -1)
        target_shape = shape_of(input_x)
        d_input = scatter_op(coords, dout[0], target_shape)
        return d_input, zeros_like(k)

    return bprop
-
-
@bprop_getters.register(P.SmoothL1Loss)
def get_bprop_smooth_l1_loss(self):
    """Build the backward function registered for the `SmoothL1Loss` primitive.

    Returns:
        Function, `bprop(prediction, target, out, dout)` returning a 2-tuple:
        the result of `SmoothL1LossGrad` (created with the forward primitive's
        `sigma`) for `prediction`, and `zeros_like(target)` for `target`.
    """
    loss_grad_op = G.SmoothL1LossGrad(self.sigma)

    def bprop(prediction, target, out, dout):
        d_prediction = loss_grad_op(prediction, target, dout)
        return d_prediction, zeros_like(target)

    return bprop
-
-
@bprop_getters.register(P.L2Loss)
def get_bprop_l2_loss(self):
    """Build the backward function registered for the `L2Loss` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(x * dout,)`.
    """

    def bprop(x, out, dout):
        # No dedicated gradient operator: the input is simply scaled by the incoming gradient.
        return (x * dout,)

    return bprop
-
-
@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Build the backward function registered for the `PReLU` primitive.

    Returns:
        Function, `bprop(x, w, out, dout)` returning `(dx, dw)` as produced by
        `PReLUGrad()(dout, x, w)`.
    """
    prelu_grad_op = G.PReLUGrad()

    def bprop(x, w, out, dout):
        # One operator call yields both gradients.
        dx, dw = prelu_grad_op(dout, x, w)
        return dx, dw

    return bprop
-
-
@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Build the backward function registered for the `LSTM` primitive.

    Two gradient operators are created with the same configuration as the
    forward primitive: one for the data inputs and one for the weight.

    Returns:
        Function, `bprop(x, hx, cx, w, out, dout)` returning
        `(dx, dhx, dcx, dw)`.
    """
    data_grad_op = G.LSTMGradData(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    weight_grad_op = G.LSTMGradWeight(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    def bprop(x, hx, cx, w, out, dout):
        # Of the five forward outputs the 1st, 4th and 5th are used;
        # of the five incoming gradients only the first three.
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = data_grad_op(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        # x is routed through F.depend together with dx before the weight gradient is computed.
        x_after_dx = F.depend(x, dx)
        dw = weight_grad_op(x_after_dx, hx, y, reserve, state)
        return dx, dhx, dcx, dw

    return bprop
-
-
@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Build the backward function registered for the `SigmoidCrossEntropyWithLogits` primitive.

    Returns:
        Function, `bprop(x, y, out, dout)` returning a 2-tuple: the result of
        `SigmoidCrossEntropyWithLogitsGrad()(x, y, dout)` for `x`, and
        `zeros_like(y)` for `y`.
    """
    loss_grad_op = G.SigmoidCrossEntropyWithLogitsGrad()

    def bprop(x, y, out, dout):
        d_logits = loss_grad_op(x, y, dout)
        return (d_logits, zeros_like(y))

    return bprop
-
-
@bprop_getters.register(P.Pad)
def get_bprop_pad(self):
    """Grad definition for `Pad` operation.

    The slice operator and the slice start offsets depend only on the forward
    primitive's `paddings` attribute, so they are prepared once here instead
    of being rebuilt inside the backward function.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, where `dx` is the
        window of `dout` that starts at the first entry of each `paddings`
        item and has the shape of `x`.
    """
    shape_op = P.Shape()
    slice_op = P.Slice()
    # The first element of every paddings item is the start offset of the slice.
    begin = tuple(item[0] for item in self.paddings)

    def bprop(x, out, dout):
        dx = slice_op(dout, begin, shape_op(x))
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.MirrorPad)
def get_bprop_mirror_pad(self):
    """Build the backward function registered for the `MirrorPad` primitive.

    Returns:
        Function, `bprop(x, paddings, out, dout)` returning a 2-tuple: the
        result of `MirrorPadGrad` (created with the forward primitive's
        `mode`) for `x`, and `zeros_like(paddings)` for `paddings`.
    """
    pad_grad_op = G.MirrorPadGrad(self.mode)

    def bprop(x, paddings, out, dout):
        d_input = pad_grad_op(dout, paddings, x)
        return (d_input, zeros_like(paddings))

    return bprop
-
-
@bprop_getters.register(P.ROIAlign)
def get_bprop_roi_align(self):
    """Build the backward function registered for the `ROIAlign` primitive.

    The gradient operator takes the shape of the forward input as a
    constructor argument, so it is created inside the backward function
    rather than once up front.

    Returns:
        Function, `bprop(inputs, rois, out, dout)` returning
        `(dx, zeros_like(rois))`.
    """
    shape_of = P.Shape()
    out_height = self.pooled_height
    out_width = self.pooled_width
    scale = self.spatial_scale
    samples = self.sample_num

    def bprop(inputs, rois, out, dout):
        feature_shape = shape_of(inputs)
        d_inputs = G.ROIAlignGrad(feature_shape, out_height, out_width, scale, samples)(dout, rois)
        return d_inputs, zeros_like(rois)

    return bprop
-
-
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Build the backward function registered for the `Conv2DBackpropInput` primitive.

    The gradient for the first input is computed with a forward `Conv2D`
    (its `pad_mode` is passed lower-cased); the gradient for the weight is
    computed with `Conv2DBackpropFilter`.

    Returns:
        Function, `bprop(x, w, f_sizes, out, dout)` returning
        `(dx, dw, zeros_like(f_sizes))`.
    """
    backprop_to_filter = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group
    )
    forward_conv = P.Conv2D(
        self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
        dilation=self.dilation, stride=self.stride, group=self.group
    )

    def bprop(x, w, f_sizes, out, dout):
        w_shape = F.shape(w)
        d_x = forward_conv(dout, w)
        d_w = backprop_to_filter(x, dout, w_shape)
        return d_x, d_w, zeros_like(f_sizes)

    return bprop
-
-
@bprop_getters.register(P.BinaryCrossEntropy)
def get_bprop_binary_cross_entropy(self):
    """Build the backward function registered for the `BinaryCrossEntropy` primitive.

    Returns:
        Function, `bprop(x, y, weight, out, dout)` returning a 3-tuple: the
        result of `BinaryCrossEntropyGrad` (created with the forward
        primitive's `reduction`) for `x`, then `zeros_like(y)` and
        `zeros_like(weight)`.
    """
    bce_grad_op = G.BinaryCrossEntropyGrad(self.reduction)

    def bprop(x, y, weight, out, dout):
        # The gradient operator takes the weight as its last argument.
        d_x = bce_grad_op(x, y, dout, weight)
        return d_x, zeros_like(y), zeros_like(weight)

    return bprop
-
-
@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
    """Build the backward function registered for the `Dropout` primitive.

    Returns:
        Function, `bprop(x, out, dout)` returning `(dx,)`, computed by
        `DropoutGrad` (created with the forward primitive's `drop_prob`) from
        the first entry of `dout` and the second entry of `out`.
    """
    dropout_grad_op = P.DropoutGrad(self.drop_prob)

    def bprop(x, out, dout):
        # out and dout are pairs; the second forward output is referred to as the mask.
        _, keep_mask = out
        d_kept, _ = dout
        return (dropout_grad_op(d_kept, keep_mask),)

    return bprop
-
-
@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
    """Build the backward function registered for the `CTCLoss` primitive.

    Returns:
        Function, `bprop(inputs, labels_indices, labels_values, sequence_length, out, dout)`
        returning a 4-tuple: `out[1]` multiplied by `dout[0]` (expanded with a
        trailing axis) for `inputs`, and `zeros_like` of each of the other
        three inputs.
    """
    add_axis = P.ExpandDims()

    def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
        # The second forward output is scaled by the incoming gradient of the first output.
        scale = add_axis(dout[0], -1)
        d_inputs = out[1] * scale
        return d_inputs, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)

    return bprop
|