|
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Define the grad rules of neural network related operations."""
- from mindspore.ops.primitive import constexpr
- from mindspore.ops.operations import nn_ops as nps
- from .grad_base import bprop_getters
- from .. import functional as F
- from .. import operations as P
- from ...common import dtype as mstype
- from ..composite.multitype_ops.zeros_like_impl import zeros_like
- from ..operations import _grad_ops as G
- from ..operations import _inner_ops as inner
- from ... import context
- from .._utils.utils import range_op, get_1d_shape
-
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
    """Create the backward function for the `BiasAdd` operation.

    The bias-gradient op is built once here, using the forward op's
    `data_format`, and reused by the returned closure.
    """
    grad_wrt_bias = G.BiasAddGrad(self.data_format)

    def bprop(x, w, out, dout):
        # `dout` is forwarded unchanged as the input gradient; the bias
        # gradient is whatever `BiasAddGrad` derives from `dout`.
        db = grad_wrt_bias(dout)
        return dout, db

    return bprop
-
-
@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
    """Create the backward function for the `Conv2D` operation.

    `out_channel` and `pad_list` are read from the primitive's attribute
    dict and stored back on `self` before the two backward ops are built
    with the same configuration as the forward op.
    """
    attrs = self.get_attr_dict()
    self.out_channel = attrs["out_channel"]
    self.pad_list = attrs["pad_list"]

    # Both backward primitives take an identical argument list.
    positional = (self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list)
    keywords = dict(mode=self.mode, dilation=self.dilation, stride=self.stride,
                    group=self.group, data_format=self.format)
    grad_wrt_input = P.Conv2DBackpropInput(*positional, **keywords)
    grad_wrt_filter = G.Conv2DBackpropFilter(*positional, **keywords)
    static_shape = P.Shape()
    dynamic_shape = P.DynamicShape()

    def bprop(x, w, out, dout):
        # A -1 entry in the static shape means the dimension is unknown at
        # compile time, so the shape is fetched with DynamicShape instead.
        shape_of_x = static_shape(x)
        if -1 in shape_of_x:
            shape_of_x = dynamic_shape(x)
        shape_of_w = static_shape(w)
        if -1 in shape_of_w:
            shape_of_w = dynamic_shape(w)

        dx = grad_wrt_input(dout, w, shape_of_x)
        dw = grad_wrt_filter(dout, x, shape_of_w)
        return dx, dw

    return bprop
-
-
@bprop_getters.register(nps.Conv3D)
def get_bprop_conv3d(self):
    """Create the backward function for the `Conv3D` operation.

    The input- and filter-gradient ops are configured with the same
    attributes as the forward primitive.
    """
    positional = (self.out_channel, self.kernel_size, self.mode)
    keywords = dict(pad_mode=self.pad_mode, pad=self.pad, stride=self.stride,
                    dilation=self.dilation, group=self.group, data_format=self.data_format)
    grad_wrt_input = nps.Conv3DBackpropInput(*positional, **keywords)
    grad_wrt_filter = G.Conv3DBackpropFilter(*positional, **keywords)
    shape_of = P.Shape()

    def bprop(x, w, out, dout):
        # Each backward op receives the shape of the tensor it differentiates.
        dx = grad_wrt_input(w, dout, shape_of(x))
        dw = grad_wrt_filter(x, dout, shape_of(w))
        return dx, dw

    return bprop
-
-
@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
    """Create the backward function for the `Conv3DTranspose` operation.

    The input gradient is computed with a forward `Conv3D`, which is given
    only elements 2..4 of `stride`/`dilation`; the filter gradient op gets
    the full attributes. Both use explicit "pad" mode with `self.pad_list`.
    """
    conv_stride = tuple(self.stride[i] for i in (2, 3, 4))
    conv_dilation = tuple(self.dilation[i] for i in (2, 3, 4))
    shared = dict(out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode,
                  pad_mode="pad", pad=self.pad_list, group=self.group,
                  data_format=self.data_format)
    grad_wrt_input = nps.Conv3D(stride=conv_stride, dilation=conv_dilation, **shared)
    grad_wrt_filter = G.Conv3DBackpropFilter(stride=self.stride, dilation=self.dilation, **shared)

    def bprop(x, w, out, dout):
        dw = grad_wrt_filter(dout, x, F.shape(w))
        dx = grad_wrt_input(dout, w)
        return dx, dw

    return bprop
-
-
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Create the backward function for the `ExtractImagePatches` operation.

    The gradient is obtained by building a (input position x output
    position) matrix with `ScatterNd` and multiplying it with the reshaped
    output gradient.
    """
    shape_of = P.Shape()
    reshape = P.Reshape()
    patch_op = inner.ExtractImagePatches(ksizes=self.ksizes,
                                         strides=self.strides,
                                         rates=self.rates,
                                         padding=self.padding)
    concat_last = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype_of = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()

    _, _, ksizes_row, ksizes_col = self.ksizes

    def bprop(x, out, dout):
        x_batch, x_depth, x_row, x_col = shape_of(x)
        _, _, out_row, out_col = shape_of(out)

        # Label every input position 1..row*col and run the labels through
        # the forward op to learn which input position feeds each patch cell.
        n_in = x_row * x_col + 1
        in_labels = cast(F.tuple_to_array(range(1, n_in)), mstype.float32)
        in_labels = reshape(in_labels, (1, 1, x_row, x_col))
        in_labels_patch = cast(patch_op(in_labels), mstype.int32)
        in_labels_patch = transpose(in_labels_patch, (0, 2, 3, 1))

        # Label every output element 0..n_out-1.
        n_out = out_row * out_col * ksizes_row * ksizes_col
        out_labels = F.tuple_to_array(range(n_out))
        out_labels = reshape(out_labels, (1, out_row, out_col, ksizes_row * ksizes_col))

        # Pair (input label, output label) and scatter ones at those pairs.
        pairs = concat_last((expand_dims(in_labels_patch, -1), expand_dims(out_labels, -1)))
        pairs = reshape(pairs, (-1, 2))
        ones = fill(dtype_of(dout), (n_out,), 1)
        selector = scatter_nd(pairs, ones, (n_in, n_out))
        # Row 0 corresponds to label 0, which no real input position uses
        # (labels start at 1), so it is sliced away.
        selector = slice_op(selector, (1, 0), (n_in - 1, n_out))

        # Bring dout into (n_out, batch * depth) layout for the matmul.
        dout_mat = transpose(dout, (0, 2, 3, 1))
        dout_mat = reshape(dout_mat, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        dout_mat = transpose(dout_mat, (1, 2, 3, 4, 0, 5))
        dout_mat = reshape(dout_mat, (-1, x_batch * x_depth))

        dx = matmul(selector, dout_mat)
        dx = reshape(dx, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
    """Create the backward function for the `DepthwiseConv2dNative` operation.

    Both backward ops are constructed with the same positional
    configuration taken from the forward primitive.
    """
    config = (self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad,
              self.pad_list, self.mode, self.stride, self.dilation, self.group)
    grad_wrt_input = G.DepthwiseConv2dNativeBackpropInput(*config)
    grad_wrt_filter = G.DepthwiseConv2dNativeBackpropFilter(*config)
    shape_of = P.Shape()

    def bprop(x, w, out, dout):
        # The shape of the differentiated tensor is passed in its place.
        dx = grad_wrt_input(shape_of(x), w, dout)
        dw = grad_wrt_filter(x, shape_of(w), dout)
        return dx, dw

    return bprop
-
-
@bprop_getters.register(P.MaxPoolWithArgmax)
def get_bprop_max_pool_with_argmax(self):
    """Create the backward function for the `MaxPoolWithArgmax` operation."""
    pool_backward = G.MaxPoolGradWithArgmax(kernel_size=self.kernel_size,
                                            strides=self.strides,
                                            pad_mode=self.pad_mode)

    def bprop(x, out, dout):
        # out[1] is the second forward output (the argmax) and dout[0] is
        # the gradient flowing into the first forward output.
        argmax = out[1]
        return (pool_backward(x, dout[0], argmax),)

    return bprop
-
-
@bprop_getters.register(G.MaxPoolGrad)
def get_bprop_max_pool_grad_grad(self):
    """Create the backward function for the `MaxPoolGrad` operation.

    Only the `grad` input receives a non-zero gradient (via
    `MaxPoolGradGrad`); `x1` and `x2` get zeros.
    """
    second_order = G.MaxPoolGradGrad(kernel_size=self.kernel_size,
                                     strides=self.strides,
                                     pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        return zeros_like(x1), zeros_like(x2), second_order(x1, x2, dout)

    return bprop
-
-
@bprop_getters.register(G.MaxPoolGradGrad)
def get_bprop_max_pool_grad_grad_grad(self):
    """Create the backward function for the `MaxPoolGradGrad` operation.

    Only the `grad` input receives a non-zero gradient (via `MaxPoolGrad`);
    `x1` and `x2` get zeros.
    """
    first_order = G.MaxPoolGrad(kernel_size=self.kernel_size,
                                strides=self.strides,
                                pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        return zeros_like(x1), zeros_like(x2), first_order(x1, x2, dout)

    return bprop
-
-
@bprop_getters.register(P.MaxPool)
def get_bprop_max_pool_grad(self):
    """Create the backward function for the `MaxPool` operation."""
    pool_backward = G.MaxPoolGrad(kernel_size=self.kernel_size,
                                  strides=self.strides,
                                  pad_mode=self.pad_mode,
                                  data_format=self.format)

    def bprop(x, out, dout):
        return (pool_backward(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(P.MaxPool3D)
def get_bprop_max_pool3d_grad(self):
    """Create the backward function for the `MaxPool3D` operation."""
    pool_backward = G.MaxPool3DGrad(kernel_size=self.kernel_size,
                                    strides=self.strides,
                                    pad_mode=self.pad_mode,
                                    pad_list=self.pad_list,
                                    data_format=self.data_format)

    def bprop(x, out, dout):
        return (pool_backward(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(G.MaxPool3DGrad)
def get_bprop_max_pool3d_grad_grad(self):
    """Create the backward function for the `MaxPool3DGrad` operation.

    Only the `grad` input receives a non-zero gradient (via
    `MaxPool3DGradGrad`); `x` and `y` get zeros.
    """
    second_order = G.MaxPool3DGradGrad(kernel_size=self.kernel_size,
                                       strides=self.strides,
                                       pad_mode=self.pad_mode,
                                       data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        d_grad = second_order(x, y, dout)
        dx = zeros_like(x)
        dy = zeros_like(y)
        return dx, dy, d_grad

    return bprop
-
-
@bprop_getters.register(G.MaxPool3DGradGrad)
def get_bprop_max_pool3d_grad_grad_grad(self):
    """Create the backward function for the `MaxPool3DGradGrad` operation.

    Only the `grad` input receives a non-zero gradient (via
    `MaxPool3DGrad`); `x` and `y` get zeros.
    """
    first_order = G.MaxPool3DGrad(kernel_size=self.kernel_size,
                                  strides=self.strides,
                                  pad_mode=self.pad_mode,
                                  data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        d_grad = first_order(x, y, dout)
        dx = zeros_like(x)
        dy = zeros_like(y)
        return dx, dy, d_grad

    return bprop
-
-
@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Create the backward function for the `AvgPool` operation."""
    pool_backward = G.AvgPoolGrad(kernel_size=self.kernel_size,
                                  strides=self.strides,
                                  pad_mode=self.pad_mode,
                                  data_format=self.format)

    def bprop(x, out, dout):
        return (pool_backward(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(P.AdaptiveAvgPool2D)
def get_bprop_adaptive_avg_pool2d_grad(self):
    """Create the backward function for the `AdaptiveAvgPool2D` operation."""
    pool_backward = G.AdaptiveAvgPool2DGrad()

    def bprop(x, out, dout):
        # The gradient op needs the forward input and the output gradient only.
        return (pool_backward(x, dout),)

    return bprop
-
@bprop_getters.register(P.AvgPool3D)
def get_bprop_avg_pool_3d_grad(self):
    """Create the backward function for the `AvgPool3D` operation.

    `pad_list` and `count_include_pad` come from the primitive's attribute
    dict; the remaining settings are read directly from `self`.
    """
    attrs = self.get_attr_dict()
    pool_backward = G.AvgPool3DGrad(kernel_size=self.kernel_size,
                                    strides=self.strides,
                                    pads=attrs['pad_list'],
                                    ceil_mode=self.ceil_mode,
                                    count_include_pad=attrs['count_include_pad'],
                                    divisor_override=self.divisor_override,
                                    data_format=self.data_format)

    def bprop(x, out, dout):
        # Only the shape of the forward input is handed to the gradient op.
        return (pool_backward(F.shape(x), dout),)

    return bprop
-
-
@bprop_getters.register(P.DropoutGenMask)
def get_bprop_dropout_gen_mask(self):
    """Create the backward function for the `DropoutGenMask` operation.

    Both inputs receive zero gradients.
    """

    def bprop(shape, keep_prob, out, dout):
        d_shape = zeros_like(shape)
        d_keep_prob = zeros_like(keep_prob)
        return (d_shape, d_keep_prob)

    return bprop
-
-
@bprop_getters.register(P.DropoutDoMask)
def get_bprop_dropout_do_mask(self):
    """Create the backward function for the `DropoutDoMask` operation.

    The incoming gradient is passed through another `DropoutDoMask` with
    the same mask and keep probability; the mask and keep probability
    themselves get zero gradients.
    """
    apply_mask = P.DropoutDoMask()

    def bprop(x, y, keep_prob, out, dout):
        dx = apply_mask(dout, y, keep_prob)
        return (dx, zeros_like(y), zeros_like(keep_prob))

    return bprop
-
-
@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Create the backward function for the `Mish` operation.

    The gradient is assembled as the sum of two terms built from `Tanh`,
    `Softplus` and their gradient ops.
    """
    tanh = P.Tanh()
    tanh_backward = G.TanhGrad()
    softplus = P.Softplus()
    softplus_backward = G.SoftplusGrad()

    def bprop(x, out, dout):
        tanh_sp = tanh(softplus(x))
        # First term: tanh(softplus(x)) * dout.
        direct_term = tanh_sp * dout
        # Second term: x * dout pushed back through tanh, then softplus.
        chained_term = softplus_backward(tanh_backward(tanh_sp, x * dout), x)
        dx = (direct_term + chained_term)
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Create the backward function for the `SeLU` operation.

    The result is `EluGrad(dout, out)` multiplied by a fixed scale constant.
    """
    # NOTE(review): whether this matches the exact SeLU derivative depends
    # on how EluGrad interprets `out` - confirm against the kernel.
    selu_scale = 1.0507009873554804934193349852946
    elu_backward = G.EluGrad()

    def bprop(x, out, dout):
        return (elu_backward(dout, out) * selu_scale,)

    return bprop
-
-
@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Create the backward function for the `MulNoNan` operation.

    Each gradient is computed with `MulNoNan` itself and then summed over
    the axes reported by `F.broadcast_gradient_args`, so it is reshaped
    back to the shape of its input.
    """
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        shape_x = F.shape(x)
        shape_y = F.shape(y)
        axes_x, axes_y = F.broadcast_gradient_args(shape_x, shape_y)

        dx = mul_no_nan(dout, y)
        # An empty tuple means no reduction is needed for that input.
        if axes_x != ():
            dx = reshape(reduce_sum(dx, axes_x), shape_x)

        dy = mul_no_nan(x, dout)
        if axes_y != ():
            dy = reshape(reduce_sum(dy, axes_y), shape_y)

        return dx, dy

    return bprop
-
-
@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Create the backward function for the `ReLU` operation."""
    relu_backward = G.ReluGrad()

    def bprop(x, out, dout):
        # The gradient op is driven by the forward output, not the input.
        return (relu_backward(dout, out),)

    return bprop
-
-
@bprop_getters.register(G.ReluGrad)
def get_bprop_relu_grad(self):
    """Create the backward function for the `ReluGrad` operation.

    The `grad` input is differentiated by applying `ReluGrad` again; the
    `y` input gets a zero gradient.
    """
    relu_backward = G.ReluGrad()

    def bprop(grad, y, out, dout):
        d_grad = relu_backward(dout, y)
        dy = zeros_like(y)
        return d_grad, dy

    return bprop
-
-
@bprop_getters.register(P.ReLU6)
def get_bprop_relu6(self):
    """Create the backward function for the `ReLU6` operation."""
    relu6_backward = G.ReLU6Grad()

    def bprop(x, out, dout):
        # Unlike ReLU above, the gradient op here is fed the forward input.
        return (relu6_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.ReLUV2)
def get_bprop_relu_v2(self):
    """Create the backward function for the `ReLUV2` operation."""
    relu_backward = G.ReluGradV2()

    def bprop(x, out, dout):
        # out[1] is the mask output of the forward op; dout[0] is the
        # gradient with respect to the first forward output.
        return (relu_backward(dout[0], out[1]),)

    return bprop
-
-
@bprop_getters.register(P.HSwish)
def get_bprop_hswish(self):
    """Create the backward function for the `HSwish` operation."""
    hswish_backward = G.HSwishGrad()

    def bprop(x, out, dout):
        return (hswish_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.HSigmoid)
def get_bprop_hsigmoid(self):
    """Create the backward function for the `HSigmoid` operation."""
    hsigmoid_backward = G.HSigmoidGrad()

    def bprop(x, out, dout):
        return (hsigmoid_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.Elu)
def get_bprop_elu(self):
    """Create the backward function for the `Elu` operation."""
    elu_backward = G.EluGrad()

    def bprop(x, out, dout):
        # The gradient op is driven by the forward output.
        return (elu_backward(dout, out),)

    return bprop
-
-
@bprop_getters.register(P.Sigmoid)
def get_bprop_sigmoid(self):
    """Create the backward function for the `Sigmoid` operation."""
    sigmoid_backward = G.SigmoidGrad()

    def bprop(x, out, dout):
        # Argument order for SigmoidGrad is (forward output, output gradient).
        return (sigmoid_backward(out, dout),)

    return bprop
-
-
@bprop_getters.register(G.SigmoidGrad)
def get_bprop_sigmoid_grad(self):
    """Create the backward function for the `SigmoidGrad` operation.

    Returns the gradients with respect to both inputs (`y`, `grad`).
    """
    sigmoid_backward = G.SigmoidGrad()

    def bprop(y, grad, out, dout):
        d_grad = sigmoid_backward(y, dout)
        dy = dout * grad * (1. - 2 * y)
        return dy, d_grad

    return bprop
-
-
@constexpr
def _get_transpose_axis(x_shp, axis):
    """Return the permutation that swaps dimension `axis` with the last one.

    Args:
        x_shp: Shape whose length gives the rank.
        axis: Dimension to swap; a negative value is offset by the rank.

    Returns:
        tuple: Identity permutation with `axis` and `rank - 1` exchanged.
    """
    rank = len(x_shp)
    if axis < 0:
        axis += rank
    last = rank - 1
    perm = list(range(rank))
    perm[axis], perm[last] = last, axis
    return tuple(perm)
-
-
@bprop_getters.register(P.Softmax)
def get_bprop_softmax(self):
    """Create the backward function for the `Softmax` operation.

    If `self.axis` is not an int, only its first element is used.
    """
    keepdims_sum = P.ReduceSum(keep_dims=True)
    sub = P.Sub()
    mul = P.Mul()
    shape_of = P.Shape()
    transpose = P.Transpose()
    axis = self.axis if isinstance(self.axis, int) else self.axis[0]

    def bprop(x, out, dout):
        # dx = (dout - sum(dout * out)) * out holds only when the softmax
        # axis is the last dimension. For any other axis, swap it with the
        # last dimension, evaluate the formula, and swap back (the same
        # permutation undoes itself).
        perm = _get_transpose_axis(shape_of(x), axis)
        out = transpose(out, perm)
        dout = transpose(dout, perm)
        weighted_sum = keepdims_sum(mul(out, dout), -1)
        dx = mul(out, sub(dout, weighted_sum))
        return (transpose(dx, perm),)

    return bprop
-
-
@bprop_getters.register(P.LogSoftmax)
def get_bprop_log_softmax(self):
    """Create the backward function for the `LogSoftmax` operation."""
    log_softmax_backward = G.LogSoftmaxGrad(self.axis)

    def bprop(x, out, dout):
        return (log_softmax_backward(out, dout),)

    return bprop
-
-
@bprop_getters.register(P.Softplus)
def get_bprop_softplus(self):
    """Create the backward function for the `Softplus` operation."""
    softplus_backward = G.SoftplusGrad()

    def bprop(x, out, dout):
        return (softplus_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.Softsign)
def get_bprop_softsign(self):
    """Create the backward function for the `Softsign` operation.

    Computes dx = dout * 1 / (1 + |x|)^2 from elementary ops.
    """
    mul = P.Mul()
    absolute = P.Abs()
    div = P.Div()
    square = P.Square()

    def bprop(x, out, dout):
        denominator = square(1 + absolute(x))
        dx = mul(dout, div(1, denominator))
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
    """Create the backward function for the `Tanh` operation."""
    tanh_backward = G.TanhGrad()

    def bprop(x, out, dout):
        # Argument order for TanhGrad is (forward output, output gradient).
        return (tanh_backward(out, dout),)

    return bprop
-
-
@bprop_getters.register(G.TanhGrad)
def get_bprop_tanh_grad(self):
    """Create the backward function for the `TanhGrad` operation.

    Returns the gradients with respect to both inputs (`y`, `grad`).
    """
    tanh_backward = G.TanhGrad()

    def bprop(y, grad, out, dout):
        d_grad = tanh_backward(y, dout)
        dy = dout * -2.0 * grad * y
        return dy, d_grad

    return bprop
-
-
@bprop_getters.register(P.Gelu)
@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
    """Create the backward function shared by the `Gelu` and `GeLU` operations."""
    gelu_backward = G.GeLUGrad()

    def bprop(x, out, dout):
        # GeLUGrad takes the output gradient, forward input and forward output.
        return (gelu_backward(dout, x, out),)

    return bprop
-
-
@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
    """Create the backward function for the `FastGeLU` operation."""
    fast_gelu_backward = G.FastGeLUGrad()

    def bprop(x, out, dout):
        return (fast_gelu_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu_2(self):
    """Create the backward function for the `FastGelu` operation.

    Uses the same `FastGeLUGrad` op as the `FastGeLU` registration.
    """
    fast_gelu_backward = G.FastGeLUGrad()

    def bprop(x, out, dout):
        return (fast_gelu_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
    """Create the backward function for the `InstanceNorm` operation.

    Gradients are produced for `x`, `gamma` and `beta`; `mean` and
    `variance` receive zeros.
    """
    norm_backward = G.InstanceNormGrad(self.epsilon, self.momentum)

    def bprop(x, gamma, beta, mean, variance, out, dout):
        # out[1] / out[2] are the statistics saved by the forward pass.
        grads = norm_backward(dout[0], x, gamma, out[1], out[2])
        return grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance)

    return bprop
-
-
@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
    """Create the backward function for the `BatchNorm` operation.

    In training mode the statistics come from the forward outputs
    (`out[3]`, `out[4]`); otherwise the `mean`/`variance` inputs are used.
    `mean` and `variance` always receive zero gradients.
    """
    is_training = self.is_training
    norm_backward = G.BatchNormGrad(is_training, self.epsilon, self.data_format)

    def bprop(x, scale, b, mean, variance, out, dout):
        # The reserve output is taken from out[2] in both modes.
        reserve = out[2]
        if is_training:
            stat_mean = out[3]
            stat_variance = out[4]
        else:
            stat_mean = mean
            stat_variance = variance
        grads = norm_backward(dout[0], x, scale, stat_mean, stat_variance, reserve)
        return grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance)

    return bprop
-
-
@bprop_getters.register(P.LayerNorm)
def get_bprop_layer_norm(self):
    """Create the backward function for the `LayerNorm` operation."""
    norm_backward = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, gamma, beta, out, dout):
        # Forward outputs 2 and 1 are passed in that order, matching the
        # (variance, mean) parameter order of the LayerNormGrad bprop below.
        dx, d_gamma, d_beta = norm_backward(x, dout[0], out[2], out[1], gamma)
        return dx, d_gamma, d_beta

    return bprop
-
-
@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
    """Create the backward function for the `LayerNormGrad` operation.

    `variance` and `mean` receive zero gradients; the others come from
    `LayerNormGradGrad`.
    """
    second_order = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, dy, variance, mean, gamma, out, dout):
        d_x, d_dy, d_gamma = second_order(x, dy, variance, mean, gamma,
                                          dout[0], dout[1], dout[2])
        d_variance = zeros_like(variance)
        d_mean = zeros_like(mean)
        return d_x, d_dy, d_variance, d_mean, d_gamma

    return bprop
-
-
@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
    """Create the backward function for the `L2Normalize` operation."""
    normalize_backward = G.L2NormalizeGrad(self.axis, self.epsilon)

    def bprop(x, out, dout):
        return (normalize_backward(x, out, dout),)

    return bprop
-
-
@bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
def get_bprop_softmax_cross_entropy_with_logits(self):
    """Create the backward function for `SoftmaxCrossEntropyWithLogits`.

    The second forward output (`out[1]`) is used as the logits gradient and
    scaled by the incoming loss gradient; `labels` gets a zero gradient.
    """
    expand_dims = P.ExpandDims()

    def bprop(logits, labels, out, dout):
        # A trailing axis is added to dout[0] so it broadcasts over out[1].
        loss_grad = expand_dims(dout[0], -1)
        d_logits = out[1] * loss_grad
        return d_logits, zeros_like(labels)

    return bprop
-
-
@bprop_getters.register(P.NLLLoss)
def get_bprop_nll_loss(self):
    """Create the backward function for the `NLLLoss` operation.

    Only `x` receives a real gradient; `target` and `weight` get zeros.
    """
    loss_backward = G.NLLLossGrad(reduction=self.reduction)

    def bprop(x, target, weight, out, dout):
        # out[1] is the total weight produced by the forward pass.
        dx = loss_backward(x, dout[0], target, weight, out[1])
        return dx, zeros_like(target), zeros_like(weight)

    return bprop
-
-
@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Create the backward function for `SparseSoftmaxCrossEntropyWithLogits`.

    When the forward op was built with `is_grad=True`, its first output is
    used directly as the gradient; otherwise a second instance with
    `is_grad=True` computes it.
    """
    is_grad = self.is_grad
    grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)

    def bprop(logits, labels, out, dout):
        if is_grad:
            d_logits = out[0]
        else:
            # The forward op produced the loss, so compute the gradient
            # separately and attach a dependency on `out`.
            d_logits = F.depend(grad_op(logits, labels), out)
        d_logits = d_logits * dout
        return d_logits, zeros_like(labels)

    return bprop
-
-
@bprop_getters.register(P.ResizeBilinear)
def get_bprop_resize_bilinear(self):
    """Create the backward function for the `ResizeBilinear` operation."""
    resize_backward = G.ResizeBilinearGrad(self.align_corners)

    def bprop(x, out, dout):
        return (resize_backward(dout, x),)

    return bprop
-
-
@bprop_getters.register(P.OneHot)
def get_bprop_onehot(self):
    """Create the backward function for the `OneHot` operation.

    All four inputs receive zero gradients.
    """

    def bprop(indices, depth, on_value, off_value, out, dout):
        d_indices = zeros_like(indices)
        d_depth = zeros_like(depth)
        d_on = zeros_like(on_value)
        d_off = zeros_like(off_value)
        return d_indices, d_depth, d_on, d_off

    return bprop
-
-
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Create the backward function for the `TopK` operation.

    The gradient of the values output (`dout[0]`) is scattered back into a
    tensor of the input's shape at the positions given by the indices
    output (`out[1]`); `k` gets a zero gradient.
    """
    scatter_nd = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_of = P.Shape()
    reshape = P.Reshape()
    dtype_of = P.DType()

    def bprop(input_x, k, out, dout):
        input_shape = shape_of(input_x)
        input_last = input_shape[-1]

        top_indices = out[1]
        indices_last = shape_of(top_indices)[-1]

        # Collapse all leading dimensions into a single row dimension.
        indices_2d = reshape(top_indices, (-1, indices_last))
        n_rows = shape_of(indices_2d)[0]

        # One offset per row, stepping by the input's last dimension
        # (presumably range_op takes start, limit, delta, dtype - confirm).
        row_offsets = range_op(0, n_rows * input_last, input_last, dtype_of(top_indices))

        # Add each row's offset (as a column, so it broadcasts) to turn
        # per-row indices into indices of the flattened input.
        flat_indices = reshape(indices_2d + expand_dims(row_offsets, -1), (-1,))
        flat_shape = get_1d_shape(input_shape)

        scattered = scatter_nd(expand_dims(flat_indices, -1),
                               reshape(dout[0], (-1,)),
                               flat_shape)
        d_input = reshape(scattered, input_shape)
        return d_input, zeros_like(k)

    return bprop
-
-
@bprop_getters.register(P.SmoothL1Loss)
def get_bprop_smooth_l1_loss(self):
    """Create the backward function for the `SmoothL1Loss` operation.

    The same gradient op is called twice with the two inputs swapped to
    obtain the gradient for each of them.
    """
    loss_backward = G.SmoothL1LossGrad(self.beta)

    def bprop(prediction, target, out, dout):
        d_prediction = loss_backward(prediction, target, dout)
        d_target = loss_backward(target, prediction, dout)
        return d_prediction, d_target

    return bprop
-
-
@bprop_getters.register(P.L2Loss)
def get_bprop_l2_loss(self):
    """Create the backward function for the `L2Loss` operation.

    The gradient is simply `x * dout`.
    """

    def bprop(x, out, dout):
        return (x * dout,)

    return bprop
-
-
@bprop_getters.register(P.RNNTLoss)
def get_bprop_rnnt_loss(self):
    """Create the backward function for the `RNNTLoss` operation.

    The second forward output (`out[1]`) is returned unchanged as the
    gradient for `acts` (it is not multiplied by `dout`); the remaining
    inputs get zeros.
    """

    def bprop(acts, labels, act_lens, label_lens, out, dout):
        d_acts = out[1]
        return d_acts, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)

    return bprop
-
-
@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Create the backward function for the `PReLU` operation."""
    prelu_backward = G.PReLUGrad()

    def bprop(x, w, out, dout):
        # Unpack explicitly so the result is a 2-tuple (dx, dw).
        d_input, d_weight = prelu_backward(dout, x, w)
        return d_input, d_weight

    return bprop
-
-
@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Create the backward function for the `LSTM` operation.

    Two implementations are defined: one built on `LSTMGradData` +
    `LSTMGradWeight`, and one built on the fused `LSTMGrad`. The fused
    variant is returned when the context's `device_target` is "CPU".
    """
    # All three gradient primitives share the forward op's configuration.
    lstm_config = dict(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad_data = G.LSTMGradData(**lstm_config)
    lstm_grad_weight = G.LSTMGradWeight(**lstm_config)
    lstm_grad = G.LSTMGrad(**lstm_config)

    def bprop(x, hx, cx, w, out, dout):
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        # F.depend ties the weight-gradient op to dx.
        x_after_dx = F.depend(x, dx)
        dw = lstm_grad_weight(x_after_dx, hx, y, reserve, state)
        return dx, dhx, dcx, dw

    # CPU variant: a single fused op yields all four gradients.
    def bprop_cpu(x, hx, cx, w, out, dout):
        y, hy, cy, reserve, _ = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
        return dx, dhx, dcx, dw

    on_cpu = context.get_context('device_target') == "CPU"
    return bprop_cpu if on_cpu else bprop
-
-
@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Create the backward function for the `DynamicRNN` operation.

    The gradient slot for `seq_length` is filled with the constant 0.
    """
    rnn_backward = G.DynamicRNNGrad(cell_type=self.cell_type,
                                    direction=self.direction,
                                    cell_depth=self.cell_depth,
                                    use_peephole=self.use_peephole,
                                    keep_prob=self.keep_prob,
                                    cell_clip=self.cell_clip,
                                    num_proj=self.num_proj,
                                    time_major=self.time_major,
                                    forget_bias=self.forget_bias)
    expand_dims = P.ExpandDims()

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        y, h, c, i, j, f, o, tanhct = out
        dy, dh, dc, _, _, _, _, _ = dout
        # Only the last element along the first axis of dh/dc is used, and
        # only the first element of the initial states.
        dh_last = dh[-1]
        dc_last = dc[-1]
        dw, db, dx, dh_prev, dc_prev = rnn_backward(x, w, b, y, init_h[0], init_c[0], h,
                                                    c, dy, dh_last, dc_last, i, j, f, o, tanhct)
        # Restore the leading axis that indexing init_h / init_c removed.
        dh_prev = expand_dims(dh_prev, 0)
        dc_prev = expand_dims(dc_prev, 0)
        return dx, dw, db, 0, dh_prev, dc_prev

    return bprop
-
-
@bprop_getters.register(P.DynamicGRUV2)
def get_bprop_dynamic_gru_v2(self):
    """Create the backward function for the `DynamicGRUV2` operation.

    The gradient slot for `seq` is filled with the constant 0.
    """
    gru_backward = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob,
                                      self.cell_clip, self.num_proj, self.time_major,
                                      self.gate_order, self.reset_after)

    def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
        y, out_h, update, reset, new, hidden_new = out
        dy, dout_h, _, _, _, _ = dout
        # Only the last element along the first axis of dout_h is used; the
        # two trailing arguments of the gradient op are passed as None.
        dh_last = dout_h[-1]
        dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = gru_backward(
            x, winput, whidden, y, init_h, out_h, dy, dh_last, update, reset, new,
            hidden_new, None, None)
        return dx, dw_input, dw_hidden, db_input, db_hidden, 0, dh_prev

    return bprop
-
-
@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Create the backward function for `SigmoidCrossEntropyWithLogits`.

    Only `x` receives a real gradient; `y` gets zeros.
    """
    loss_backward = G.SigmoidCrossEntropyWithLogitsGrad()

    def bprop(x, y, out, dout):
        return (loss_backward(x, y, dout), zeros_like(y))

    return bprop
-
-
@bprop_getters.register(P.Pad)
def get_bprop_pad(self):
    """Create the backward function for the `Pad` operation.

    The gradient is the region of `dout` that corresponds to the original
    input: a slice that starts at the leading pad amount of every dimension
    and has the shape of `x`.

    The `Slice` primitive and the start offsets depend only on the forward
    op's static `paddings`, so they are built once here instead of on every
    bprop call, as the other getters in this file do.
    """
    shape_op = P.Shape()
    slice_op = P.Slice()
    # First element of each padding pair = amount padded before the data.
    begin = tuple(item[0] for item in self.paddings)

    def bprop(x, out, dout):
        dx = slice_op(dout, begin, shape_op(x))
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.MirrorPad)
def get_bprop_mirror_pad(self):
    """Create the backward function for the `MirrorPad` operation.

    `paddings` receives a zero gradient.
    """
    pad_backward = G.MirrorPadGrad(self.mode)

    def bprop(x, paddings, out, dout):
        return (pad_backward(dout, paddings), zeros_like(paddings))

    return bprop
-
-
@bprop_getters.register(P.ROIAlign)
def get_bprop_roi_align(self):
    """Create the backward function for the `ROIAlign` operation.

    `rois` receives a zero gradient.
    """
    shape_of = P.Shape()
    pooled_height = self.pooled_height
    pooled_width = self.pooled_width
    spatial_scale = self.spatial_scale
    sample_num = self.sample_num

    def bprop(inputs, rois, out, dout):
        # The gradient primitive takes the input shape as a constructor
        # argument, so it has to be created here rather than in the getter.
        roi_backward = G.ROIAlignGrad(shape_of(inputs),
                                      pooled_height,
                                      pooled_width,
                                      spatial_scale,
                                      sample_num)
        d_inputs = roi_backward(dout, rois)
        return d_inputs, zeros_like(rois)

    return bprop
-
-
@bprop_getters.register(P.Conv2DTranspose)
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Create the backward function for `Conv2DBackpropInput` / `Conv2DTranspose`.

    The input gradient is computed with a forward `Conv2D` and the filter
    gradient with `Conv2DBackpropFilter`; `f_sizes` gets a zero gradient.
    """
    attrs = self.get_attr_dict()
    pad_list = attrs['pad_list']
    out_channel = attrs['out_channel']
    shared = dict(dilation=self.dilation, stride=self.stride, group=self.group,
                  data_format=self.format)
    grad_wrt_filter = G.Conv2DBackpropFilter(
        out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list,
        mode=self.mode, **shared
    )
    # Conv2D is given the pad mode in lower case.
    grad_wrt_input = P.Conv2D(
        out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
        **shared
    )
    static_shape = P.Shape()
    dynamic_shape = P.DynamicShape()

    def bprop(x, w, f_sizes, out, dout):
        # A -1 entry means the static shape is unknown; fall back to
        # DynamicShape in that case.
        shape_of_w = static_shape(w)
        if -1 in shape_of_w:
            shape_of_w = dynamic_shape(w)
        dx = grad_wrt_input(dout, w)
        dw = grad_wrt_filter(x, dout, shape_of_w)
        return dx, dw, zeros_like(f_sizes)

    return bprop
-
-
@bprop_getters.register(P.BinaryCrossEntropy)
def get_bprop_binary_cross_entropy(self):
    """Create the backward function for the `BinaryCrossEntropy` operation.

    Only `x` receives a real gradient; `y` and `weight` get zeros.
    """
    loss_backward = G.BinaryCrossEntropyGrad(self.reduction)

    def bprop(x, y, weight, out, dout):
        # Note the argument order: dout comes before weight.
        dx = loss_backward(x, y, dout, weight)
        d_y = zeros_like(y)
        d_weight = zeros_like(weight)
        return dx, d_y, d_weight

    return bprop
-
-
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_ce_with_logits_loss(self):
    """Create the backward function for the `BCEWithLogitsLoss` operation.

    Gradients are produced for `predict` and `target`; `weight` and
    `pos_weight` get zeros. With 'mean' reduction both gradients are
    divided by the element count.
    """
    reduction = self.reduction
    mul = P.Mul()
    sigmoid = P.Sigmoid()
    add = P.Add()
    sub = P.Sub()
    size = P.Size()
    neg = P.Neg()
    log = P.Log()

    def bprop(predict, target, weight, pos_weight, out, dout):
        prob = sigmoid(predict)
        if pos_weight is not None:
            # dx = ((t*pw + 1 - t) * sigmoid(x) - t*pw) * dout
            weighted_target = mul(target, pos_weight)
            coeff = sub(add(weighted_target, 1), target)
            dx = mul(sub(mul(coeff, prob), weighted_target), dout)
            # d_target = (log(1 - sigmoid(x)) - pw * log(sigmoid(x))) * dout
            log_one_minus_prob = log(sub(1, prob))
            weighted_log_prob = mul(pos_weight, log(prob))
            grad_target = mul(sub(log_one_minus_prob, weighted_log_prob), dout)
        else:
            dx = mul((prob - target), dout)
            grad_target = mul(predict, neg(dout))
        if weight is not None:
            dx = mul(dx, weight)
            grad_target = mul(grad_target, weight)
        if reduction == 'mean':
            dx = dx / size(dx)
            grad_target = grad_target / size(target)
        return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)

    return bprop
-
-
@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
    """Create the backward function for the `KLDivLoss` operation."""
    loss_backward = G.KLDivLossGrad(self.reduction)

    def bprop(x, y, out, dout):
        # Unpack explicitly so the result is a 2-tuple (dx, dy).
        d_x, d_y = loss_backward(x, y, dout)
        return d_x, d_y

    return bprop
-
-
@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
    """Create the backward function for the `Dropout` operation."""
    dropout_backward = G.DropoutGrad(self.keep_prob)

    def bprop(x, out, dout):
        # The mask is the second forward output; the incoming gradient of
        # interest is the one for the first forward output.
        _, mask = out
        dy, _ = dout
        return (dropout_backward(dy, mask),)

    return bprop
-
-
@bprop_getters.register(P.Dropout2D)
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
    """Create the backward function for the `Dropout2D` and `Dropout3D` operations.

    The gradient with respect to the input is the incoming gradient of the
    first output, masked by the forward mask and, when `keep_prob` is
    non-zero, scaled by `1 / keep_prob`.

    Fix: the previous code unpacked the mask from `dout` and multiplied the
    forward input `x` by it, so the result ignored the incoming gradient.
    The mask is now taken from `out` and applied to `dout[0]`, in line with
    the `Dropout` bprop in this file.
    """
    dtype = P.DType()
    cast = P.Cast()
    mul = P.Mul()
    keep_prob = self.keep_prob

    def bprop(x, out, dout):
        _, mask = out
        dy, _ = dout
        # Compute in float32, then cast back to the input dtype at the end.
        scale = cast(mask, mstype.float32)
        if keep_prob != 0:
            # Guarded so that keep_prob == 0 does not divide by zero.
            scale = scale * (1 / keep_prob)
        dx = mul(cast(dy, mstype.float32), scale)
        dx = cast(dx, dtype(x))
        return (dx,)

    return bprop
-
-
@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
    """Create the backward function for the `CTCLoss` operation.

    The second forward output (`out[1]`) is scaled by the incoming loss
    gradient to give the gradient for `inputs`; the other inputs get zeros.
    """
    expand_dims = P.ExpandDims()

    def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
        # A trailing axis is added to dout[0] so it broadcasts over out[1].
        loss_grad = expand_dims(dout[0], -1)
        d_inputs = out[1] * loss_grad
        return (d_inputs, zeros_like(labels_indices), zeros_like(labels_values),
                zeros_like(sequence_length))

    return bprop
-
-
@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Create the backward function for the `BasicLSTMCell` operation.

    Three gradient ops are chained: cell-state grad -> input grad ->
    weight grad.
    """
    cstate_backward = G.BasicLSTMCellCStateGrad(forget_bias=self.forget_bias,
                                                activation=self.activation)
    weight_backward = G.BasicLSTMCellWeightGrad()
    input_backward = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, dct_prev = cstate_backward(c, dht, dct, it, jt, ft, ot, tanhct)
        dxt, dht_prev = input_backward(dgate, w)
        # F.depend ties the weight-gradient op to dxt.
        dw, db = weight_backward(F.depend(x, dxt), h, dgate)
        return dxt, dht_prev, dct_prev, dw, db

    return bprop
-
-
@bprop_getters.register(P.LRN)
def get_bprop_lrn(self):
    """Create the backward function for the `LRN` operation."""
    lrn_backward = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)

    def bprop(x, out, dout):
        return (lrn_backward(dout, x, out),)

    return bprop
|