|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Operators for gradients."""
- import math
- from functools import partial
- from mindspore._checkparam import _check_3d_int_or_tuple
- from .. import signature as sig
- from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
- from ..._checkparam import Validator as validator, Rel
- from .._utils import get_concat_offset
- from ...common import dtype as mstype
- from .. import functional as F
- from ... import context
-
-
- class AbsGrad(PrimitiveWithInfer):
- """Computes gradients for abs operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize AbsGrad"""
-
- def infer_shape(self, y, dy):
- return y
-
- def infer_dtype(self, y, dy):
- return y
-
-
- class ACosGrad(PrimitiveWithInfer):
- """
- Computes the gradient of ACos for the input element-wise.
-
- Returns:
- Tensor, has the same type as input.
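-
- Example (a minimal sketch, not from the original source; assumes this module is
- imported as ``G`` and that the kernel computes ``dout * -1 / sqrt(1 - x * x)``):
-
- >>> acos_grad = G.ACosGrad()
- >>> x = Tensor(np.array([0.5]), mindspore.float32)
- >>> dout = Tensor(np.array([1.0]), mindspore.float32)
- >>> output = acos_grad(x, dout)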
- """
-
- @prim_attr_register
- def __init__(self):
- """Initialize ACosGrad"""
-
- def infer_shape(self, x, dout):
- validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, dout):
- args = {"x": x, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return x
-
-
- class AcoshGrad(PrimitiveWithInfer):
- """Performs grad of Acosh operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize AcoshGrad"""
-
- def infer_shape(self, x, dout):
- validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, dout):
- args = {"x": x, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return x
-
-
- class AsinGrad(PrimitiveWithInfer):
- """
- Computes the gradient of Asin for the input element-wise.
-
- Returns:
- Tensor, has the same type as input.
- """
-
- @prim_attr_register
- def __init__(self):
- """Initialize AsinGrad"""
-
- def infer_shape(self, x, dout):
- validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, dout):
- args = {"x": x, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return x
-
-
- class AsinhGrad(PrimitiveWithInfer):
- """Performs grad of Asinh operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize AsinhGrad"""
-
- def infer_shape(self, x, dout):
- validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, dout):
- args = {"x": x, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return x
-
-
- class ReciprocalGrad(PrimitiveWithInfer):
- """Performs grad of Reciprocal operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize ReciprocalGrad"""
-
- def infer_shape(self, x_shape, dout_shape):
- validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_dtype, dout_dtype):
- args = {"x": x_dtype, "dout": dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
- return x_dtype
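-
- # Illustrative note (an assumption, not from the original source): ReciprocalGrad is
- # expressed in terms of the forward output y = 1 / x, so the backend kernel computes
- # dx = -dout * y * y rather than re-deriving the gradient from x.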
-
-
- class RsqrtGrad(PrimitiveWithInfer):
- """Performs grad of Rsqrt operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize RsqrtGrad"""
-
- def infer_shape(self, x_shape, dout_shape):
- validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_dtype, dout_dtype):
- args = {"x": x_dtype, "dout": dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
- self.name)
- return x_dtype
-
-
- class SoftmaxGrad(PrimitiveWithInfer):
- """Performs grad of Softmax operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize SoftmaxGrad"""
-
- def infer_shape(self, x_shape, dout_shape):
- validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_dtype, dout_dtype):
- args = {"x": x_dtype, "dout": dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
- return x_dtype
-
-
- class SqrtGrad(PrimitiveWithInfer):
- """Performs grad of Sqrt operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize SqrtGrad"""
-
- def infer_shape(self, x_shape, dout_shape):
- validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_dtype, dout_dtype):
- args = {"x": x_dtype, "dout": dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
- return x_dtype
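-
- # Illustrative note (an assumption, not from the original source): SqrtGrad takes the
- # forward output y = sqrt(x) together with the incoming gradient, so the backend
- # kernel computes dx = dout * 0.5 / y.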
-
-
- class BatchNormGrad(PrimitiveWithInfer):
- """Performs grad of BatchNorm operation."""
-
- @prim_attr_register
- def __init__(self, is_training=False, epsilon=1e-5):
- self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
- self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
- self.add_prim_attr('data_format', "NCHW")
-
- def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
- validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
- return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)
-
- def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
- return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
-
-
- class BiasAddGrad(PrimitiveWithInfer):
- """Computes gradients of BiasAdd."""
-
- @prim_attr_register
- def __init__(self, data_format="NCHW"):
- self.init_prim_io_names(inputs=['dout'], outputs=['output'])
- self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
- if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
- self.add_prim_attr('data_format', self.format)
-
- def infer_shape(self, d_output):
- channel = d_output[1] if self.format == "NCHW" else d_output[-1]
- return (channel,)
-
- def infer_dtype(self, dout_dtype):
- return dout_dtype
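-
- # Illustrative usage (a sketch, not from the original source): BiasAddGrad reduces
- # `dout` over every axis except the channel axis, so for an NCHW gradient of shape
- # (N, C, H, W) it returns the bias gradient of shape (C,):
- #
- # bias_add_grad = BiasAddGrad()
- # dout = Tensor(np.ones([2, 3, 4, 4]), mindspore.float32)
- # db = bias_add_grad(dout) # db.shape == (3,)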
-
-
- class KLDivLossGrad(PrimitiveWithInfer):
- """Computes gradients for `KLDivLoss` operation."""
-
- @prim_attr_register
- def __init__(self, reduction='mean'):
- self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
-
- def infer_shape(self, x_shape, y_shape, doutput_shape):
- validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
- return x_shape, y_shape
-
- def infer_dtype(self, x_type, y_type, doutput_type):
- args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
- return x_type, y_type
-
-
- class BinaryCrossEntropyGrad(PrimitiveWithInfer):
- """Computes gradients for `BinaryCrossEntropy` operation."""
-
- @prim_attr_register
- def __init__(self, reduction='mean'):
- self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
-
- def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
- validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
- if weight_shape:
- validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
- args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
- if weight_type:
- validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, self.name, TypeError)
- return x_type
-
-
- class ConcatOffset(PrimitiveWithInfer):
- """primitive for computing Concat's gradient."""
-
- @prim_attr_register
- def __init__(self, N=2, axis=0):
- """Initialize ConcatOffset"""
-
- def __infer__(self, input_x):
- axis = self.axis
- x_shp = input_x['shape']
- x_type = input_x['dtype']
- offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
- self.add_prim_attr('T', x_type[0].element_type())
- offset_values = []
- for i in range(len(x_shp)):
- values = []
- for j in range(len(x_shp[0])):
- value = 0
- if j == axis:
- value = offset[i]
- values.append(value)
- offset_values.append(tuple(values))
- out = {'shape': None,
- 'dtype': None,
- 'value': tuple(offset_values)}
- return out
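-
- # Worked example (a sketch, not from the original source): for two inputs of shapes
- # (2, 3) and (4, 3) concatenated along axis 0, get_concat_offset yields offsets
- # (0, 2), so __infer__ returns the value ((0, 0), (2, 0)): each inner tuple holds
- # that input's start offset on the concat axis and zeros elsewhere.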
-
-
- class Conv3DBackpropFilter(PrimitiveWithInfer):
- """
- Computes the gradients of convolution 3D with respect to the filter.
-
- Args:
- out_channel (int): The dimension of the output.
- kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
- mode (int): Modes for different convolutions. Not currently used.
- pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
- pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
- head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
- integers, the paddings of head, tail, top, bottom, left and right are equal to pad[0], pad[1], pad[2],
- pad[3], pad[4] and pad[5] correspondingly.
- stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: (1, 1, 1, 1, 1).
- dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: (1, 1, 1, 1, 1).
- group (int): Splits input into groups. Default: 1.
- data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
-
- Inputs:
- - **x** (Tensor) - The input of the convolution with shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
- - **dout** (Tensor) - The gradients w.r.t. the output of the convolution. The shape conforms to the default
- data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
- - **w_size** (Tensor) - A tuple describing the shape of the weight, which conforms to the format
- :math:`(C_{out}, C_{in}, K_d, K_h, K_w)`.
-
- Outputs:
- Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.
-
- Supported Platforms:
- ``Ascend``
-
- Examples:
- >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
- >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
- >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
- >>> conv3d_backprop_filter = P.Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2))
- >>> output = conv3d_backprop_filter(x, dout, F.shape(w))
- >>> print(output.shape)
- (32, 32, 4, 6, 2)
- """
-
- @prim_attr_register
- def __init__(self,
- out_channel,
- kernel_size,
- pad_mode="valid",
- pad=0,
- mode=1,
- stride=(1, 1, 1, 1, 1),
- dilation=(1, 1, 1, 1, 1),
- group=1,
- data_format="NCDHW"):
- """Initialize Convolution"""
- self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
- self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
- self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
- self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
- self.add_prim_attr('strides', self.stride)
- self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
- self.add_prim_attr('dilations', self.dilation)
- validator.check_value_type('pad', pad, (int, tuple), self.name)
- if isinstance(pad, int):
- pad = (pad,) * 6
- validator.check_equal_int(len(pad), 6, 'pad size', self.name)
- self.pad_list = pad
- self.add_prim_attr('pads', self.pad_list)
-
- self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
- if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
- raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
- if self.pad_mode == 'pad':
- for item in pad:
- validator.check_non_negative_int(item, 'pad item', self.name)
- self.add_prim_attr('pad_mode', self.pad_mode)
-
- self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
- self.group = validator.check_positive_int(group, 'group', self.name)
- self.add_prim_attr('groups', self.group)
- self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
- self.add_prim_attr('data_format', self.format)
- self.add_prim_attr('io_format', "NCDHW")
-
- def __infer__(self, x, doutput, w_size):
- w_size_v = w_size['value']
- validator.check_value_type('w_size', w_size_v, [tuple], self.name)
- for i, dim_len in enumerate(w_size_v):
- validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
- args = {"x": x['dtype'], "doutput": doutput['dtype']}
- valid_dtypes = [mstype.float16, mstype.float32]
- validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-
- validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
- validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name)
- validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name)
-
- # infer shape
- x_shape = x['shape']
- dout_shape = doutput['shape']
- kernel_d = self.kernel_size[0]
- kernel_h = self.kernel_size[1]
- kernel_w = self.kernel_size[2]
- stride_d = self.stride[2]
- stride_h = self.stride[3]
- stride_w = self.stride[4]
- dilation_d = self.dilation[2]
- dilation_h = self.dilation[3]
- dilation_w = self.dilation[4]
- # pad_mode defaults to "valid". For "valid" the pad list is all zeros; for "same" it is computed below; for "pad" the user-supplied pad list is used as-is.
- if self.pad_mode == "valid":
- self.pad_list = (0, 0, 0, 0, 0, 0)
- if self.pad_mode == "same":
- pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2])
- pad_head = math.floor(pad_needed_d / 2)
- pad_tail = pad_needed_d - pad_head
-
- pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3])
- pad_top = math.floor(pad_needed_h / 2)
- pad_bottom = pad_needed_h - pad_top
-
- pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4])
- pad_left = math.floor(pad_needed_w / 2)
- pad_right = pad_needed_w - pad_left
- self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
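- # Worked example (illustrative): with x depth 10, dout depth 10, stride 1,
- # dilation 1 and kernel depth 4, pad_needed_d = max(0, 9 + 3 + 1 - 10) = 3,
- # which is split as pad_head = 1 and pad_tail = 2.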
-
- self.add_prim_attr('pads', self.pad_list)
- out = {
- 'value': None,
- 'shape': w_size_v,
- 'dtype': mstype.float32,
- }
- return out
-
-
- class Conv2DBackpropFilter(PrimitiveWithInfer):
- """
- Computes the gradients of convolution with respect to the filter.
-
- Args:
- out_channel (int): The dimensionality of the output space.
- kernel_size (Union[int, tuple[int]]): The size of the convolution window.
- pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
- pad (int): The pad value to be filled. Default: 0.
- mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
- 2: deconvolution, 3: depthwise convolution. Default: 1.
- stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
- dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
- group (int): Splits input into groups. Default: 1.
- data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
- Default: 'NCHW'.
-
- Returns:
- Tensor, the gradients of convolution.
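-
- Example (an illustrative sketch, not from the original source; assumes this module
- is imported as ``G``, with shapes following the checks in __infer__ under "valid"
- padding and stride 1):
-
- >>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
- >>> x = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
- >>> w_size = (32, 32, 3, 3)
- >>> conv2d_backprop_filter = G.Conv2DBackpropFilter(out_channel=32, kernel_size=(3, 3))
- >>> dw = conv2d_backprop_filter(dout, x, w_size)
- >>> print(dw.shape)
- (32, 32, 3, 3)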
- """
-
- @prim_attr_register
- def __init__(self,
- out_channel,
- kernel_size,
- pad_mode="valid",
- pad=0,
- pad_list=(0, 0, 0, 0),
- mode=1,
- stride=(1, 1),
- dilation=(1, 1, 1, 1),
- group=1,
- data_format="NCHW"):
- """Initialize Convolution"""
- self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
- self.out_channel = out_channel
- self.kernel_size = kernel_size
- self.mode = mode
- pad_mode = pad_mode.upper()
- self.add_prim_attr('pad_mode', pad_mode)
- self.pad = pad
- if isinstance(stride, tuple) and len(stride) == 4:
- self.stride = (stride[2], stride[3])
- self.add_prim_attr('stride', self.stride)
- self.dilation = dilation
- self.group = group
- self.add_prim_attr('groups', group)
- self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
- if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
- self.add_prim_attr('data_format', self.format)
-
- def __infer__(self, doutput, x, w_size):
- w_size_v = w_size['value']
- validator.check_value_type('w_size', w_size_v, [tuple], self.name)
- for i, dim_len in enumerate(w_size_v):
- validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
- args = {"x": x['dtype'], "doutput": doutput['dtype']}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
- self.name)
- out = {
- 'value': None,
- 'shape': w_size_v,
- 'dtype': doutput['dtype'],
- }
- return out
-
-
- class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
- """
- Returns the gradient of filter for DepthwiseConv2dNative.
-
- Applies depthwise conv2d to the input, producing channel_multiplier output channels for each input channel.
-
- Refer to class DepthwiseConv2dNative for more details.
-
- Args:
- channel_multiplier (int): The multiplier for the original output convolution.
- kernel_size (int or tuple): The size of the conv kernel.
- mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
- 2: deconvolution, 3: depthwise convolution. Default: 3.
- pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
- pad (int): The pad value to be filled. Default: 0.
- pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
- stride (int): The stride to be applied to the convolution filter. Default: 1.
- dilation (int): Specifies the space to use between kernel elements. Default: 1.
- group (int): Splits input into groups. Default: 1.
-
- Returns:
- Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
- """
-
- @prim_attr_register
- def __init__(self,
- channel_multiplier,
- kernel_size,
- pad_mode="valid",
- pad=0,
- pads=(0, 0, 0, 0),
- mode=3,
- stride=1,
- dilation=1,
- group=1):
- """Initialize Convolution"""
- self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
- self.channel_multiplier = channel_multiplier
- self.kernel_size = kernel_size
- self.mode = mode
- self.pad_mode = pad_mode
- self.pad = pad
- self.pads = pads
- self.stride = stride
- self.dilation = dilation
- self.group = group
- self.add_prim_attr('data_format', "NCHW")
-
- def __call__(self, x, w_size, dout):
- raise NotImplementedError
-
- def __infer__(self, x, w_size, dout):
- w_size_v = w_size['value']
- args = {'x': x['dtype'], 'dout': dout['dtype']}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- out = {
- 'value': None,
- 'shape': w_size_v,
- 'dtype': dout['dtype'],
- }
- return out
-
-
- class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
- """
- Returns the gradient of input for DepthwiseConv2dNative.
-
- Applies depthwise conv2d to the input, producing channel_multiplier output channels for each input channel.
-
- Args:
- channel_multiplier (int): The multiplier for the original output convolution.
- kernel_size (int or tuple): The size of the conv kernel.
- mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
- 2: deconvolution, 3: depthwise convolution. Default: 3.
- pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
- pad (int): The pad value to be filled. Default: 0.
- pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
- stride (int): The stride to be applied to the convolution filter. Default: 1.
- dilation (int): Specifies the space to use between kernel elements. Default: 1.
- group (int): Splits input into groups. Default: 1.
-
- Returns:
- Tensor, the value is the gradient of input for DepthwiseConv2dNative.
- """
-
- @prim_attr_register
- def __init__(self,
- channel_multiplier,
- kernel_size,
- pad_mode="valid",
- pad=0,
- pads=(0, 0, 0, 0),
- mode=3,
- stride=1,
- dilation=1,
- group=1):
- """Initialize Convolution"""
- self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
- self.channel_multiplier = channel_multiplier
- self.kernel_size = kernel_size
- self.mode = mode
- self.pad_mode = pad_mode
- self.pad = pad
- self.pads = pads
- self.stride = stride
- self.dilation = dilation
- self.group = group
- self.add_prim_attr('data_format', "NCHW")
-
- def __call__(self, x_size, w, dout):
- raise NotImplementedError
-
- def __infer__(self, x_size, w, dout):
- args = {'w': w['dtype'], 'dout': dout['dtype']}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- x_size_v = x_size['value']
- out = {
- 'value': None,
- 'shape': x_size_v,
- 'dtype': dout['dtype'],
- }
- return out
-
-
- class DropoutGrad(PrimitiveWithInfer):
- """
- The gradient of Dropout. The forward Dropout randomly zeroes elements of the
- input tensor with probability :math:`1 - keep\_prob` during training.
-
- Args:
- keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
- means dropping out 10% of input units.
-
- Inputs:
- - **dy** (Tensor) - The gradient flowing from the output of Dropout.
- - **mask** (Tensor) - The mask generated by the forward Dropout.
-
- Outputs:
- Tensor, the gradient of the Dropout input, with the same shape and dtype as `dy`.
-
- Examples:
- >>> dropout_grad = ops.DropoutGrad(keep_prob=0.5)
- >>> dy = Tensor(np.ones([20, 16, 50, 50]), mindspore.float32)
- >>> mask = Tensor(np.ones([20, 16, 50, 50]), mindspore.float32)
- >>> output = dropout_grad(dy, mask)
- """
-
- @prim_attr_register
- def __init__(self, keep_prob=0.5):
- self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
-
- def infer_shape(self, dy_shape, mask_shape):
- return dy_shape
-
- def infer_dtype(self, dy_dtype, mask_dtype):
- valid_dtypes = (mstype.float16, mstype.float32)
- validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
- validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name)
- return dy_dtype
-
-
- class FlattenGrad(PrimitiveWithInfer):
- """Performs gradients of Flatten."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])
-
- def __infer__(self, *args):
- out = {
- 'value': None,
- 'shape': args[1]['value'],
- 'dtype': args[0]['dtype'],
- }
- return out
-
-
- class FusedBatchNormGrad(Primitive):
- """Gradients of FusedBatchNorm operation."""
-
- @prim_attr_register
- def __init__(self, epsilon=0.0, momentum=0.1):
- self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'],
- outputs=['dx', 'bn_scale', 'bn_bias'])
-
- def __call__(self, dy, x, scale, save_mean, save_inv_variance):
- raise NotImplementedError
-
-
- class FusedBatchNormGradCPU(PrimitiveWithInfer):
- """Gradients of FusedBatchNorm operation for CPU."""
-
- @prim_attr_register
- def __init__(self, epsilon=0.0, momentum=0.1):
- self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'bias', 'save_mean', 'save_inv_variance'],
- outputs=['dx', 'bn_scale', 'bn_bias'])
- self.add_prim_attr('data_format', "NCHW")
-
- def infer_shape(self, dy_shape, x_shape, scale_shape, bias_shape, save_mean_shape, save_inv_variance_shape):
- return (x_shape, scale_shape, bias_shape)
-
- def infer_dtype(self, dy_type, x_type, scale_type, bias_type, save_mean_type, save_inv_variance_type):
- return (x_type, scale_type, bias_type)
-
-
- class FusedBatchNormGradEx(PrimitiveWithInfer):
- """Gradients of FusedBatchNormEx operation."""
-
- @prim_attr_register
- def __init__(self, epsilon=0.0, momentum=0.1, data_format="NCHW"):
- self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
- outputs=['dx', 'bn_scale', 'bn_bias'])
- self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
- if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
- self.add_prim_attr('data_format', self.format)
-
- def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
- return (x_shape, scale_shape, scale_shape)
-
- def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type):
- return (x_type, scale_type, scale_type)
-
-
- class UniqueGrad(Primitive):
- """Gradients of Unique operation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
-
- def __call__(self, dy, y):
- raise NotImplementedError
-
-
- class BNTrainingReduceGrad(PrimitiveWithInfer):
- """Gradients of FusedBatchNorm operation."""
-
- @prim_attr_register
- def __init__(self, epsilon=0.0001):
- _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
- self.init_prim_io_names(inputs=_inputs, outputs=['y'])
-
- def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
- return grads
-
- def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
- return grads
-
-
- class BNTrainingUpdateGrad(PrimitiveWithInfer):
- """Gradients of FusedBatchNorm operation."""
-
- @prim_attr_register
- def __init__(self, epsilon=0.0001):
- self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
- outputs=['diff_scale', 'diff_offset'])
-
- def infer_shape(self, grads, x, batch_mean, batch_variance):
- return (batch_mean, batch_variance)
-
- def infer_dtype(self, grads, x, batch_mean, batch_variance):
- return (batch_mean, batch_variance)
-
-
- class GeluGrad(PrimitiveWithInfer):
- """Gradients of Gelu operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize GeluGrad"""
-
- def infer_shape(self, y_backprop_shape, x_shape, y_shape):
- return x_shape
-
- def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
- tuple(map(partial(validator.check_tensor_dtype_valid,
- valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
- ("y_backprop", "x", "y"),
- (y_backprop_dtype, x_dtype, y_dtype)))
- return x_dtype
-
-
- class FastGeluGrad(PrimitiveWithInfer):
- """Gradients of FastGelu operation."""
-
- @prim_attr_register
- def __init__(self):
- """init FastGeluGrad"""
-
- def infer_shape(self, y_backprop_shape, x_shape):
- return x_shape
-
- def infer_dtype(self, y_backprop_dtype, x_dtype):
- tuple(map(partial(validator.check_tensor_dtype_valid,
- valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
- ("y_backprop", "x"),
- (y_backprop_dtype, x_dtype)))
- return x_dtype
-
-
- class _PoolGrad(PrimitiveWithInfer):
- """Gradients of the max/avg pool operation."""
-
- @prim_attr_register
- def __init__(self, ksize, strides, padding="VALID", data_format="NCHW"):
- self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
-
- validator.check_value_type('ksize', ksize, [int, tuple], self.name)
- validator.check_value_type('strides', strides, [int, tuple], self.name)
- self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
- self.add_prim_attr("padding", self.padding)
- self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
- if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
- self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
- if not self.is_maxpoolgradwithargmax:
- self.add_prim_attr('data_format', self.format)
-
- def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
- validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
- error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number "
- f"or a tuple of two or four positive int numbers, but got {arg_val}")
- if isinstance(arg_val, int):
- ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
- elif len(arg_val) == 2:
- ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
- elif len(arg_val) == 4:
- ret = arg_val
- else:
- raise error_msg
- # whether all elements of tuple are positive integers
- for item in ret:
- if not isinstance(item, int) or item <= 0:
- raise error_msg
- return ret
-
- ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
- self.ksize = ksize if self.format == "NCHW" else [ksize[0], ksize[2], ksize[3], ksize[1]]
- self.add_prim_attr("ksize", self.ksize)
-
- strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
- self.strides = strides if self.format == "NCHW" else [strides[0], strides[2], strides[3], strides[1]]
- self.add_prim_attr("strides", self.strides)
-
-
- class AvgPoolGrad(_PoolGrad):
- """Gradients of the avg pool operation for ge."""
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID"):
- super(AvgPoolGrad, self).__init__(ksize, strides, padding)
-
- def __infer__(self, origin_input, dout):
- out = {
- 'value': None,
- 'shape': tuple(origin_input['value']),
- 'dtype': dout['dtype'],
- }
-
- return out
-
-
- class AvgPoolGradVm(_PoolGrad):
- """Gradients of the avg pool operation for vm."""
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID"):
- super(AvgPoolGradVm, self).__init__(ksize, strides, padding)
- self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])
-
- def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
- out = {
- 'value': None,
- 'shape': tuple(origin_input['value']),
- 'dtype': dout['dtype'],
- }
-
- return out
-
-
- class AvgPoolGradGpu(_PoolGrad):
- """Gradients of the avg pool operation for gpu."""
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"):
- super(AvgPoolGradGpu, self).__init__(ksize, strides, padding, data_format)
-
- def infer_shape(self, x1_shape, x2_shape, grad_shape):
- return x1_shape
-
- def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
- return x1_dtype
-
-
- class MaxPoolGrad(_PoolGrad):
- """Performs gradients of the max pool operation."""
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"):
- super(MaxPoolGrad, self).__init__(ksize, strides, padding, data_format)
-
- def infer_shape(self, x1_shape, x2_shape, grad_shape):
- return x1_shape
-
- def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
- return x1_dtype
-
-
- class MaxPoolGradGrad(_PoolGrad):
- r"""
- Performs gradients of the MaxPoolGrad operation.
-
- Args:
- ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
- is an int number that represents height and width are both ksize, or a tuple
- of two int numbers that represent height and width respectively. Default: 1.
- strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
- the height and width of movement are both strides, or a tuple of two int numbers that
- represent height and width of movement respectively. Default: 1.
- padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
- Default: "valid".
-
- - same: Adopts the way of completion. The height and width of the output will be the same as
- the input. The total number of padding will be calculated in horizontal and vertical
- directions and evenly distributed to top and bottom, left and right if possible.
- Otherwise, the last extra padding will be done from the bottom and the right side.
-
- - valid: Adopts the way of discarding. The possible largest height and width of output
- will be returned without padding. Extra pixels will be discarded.
-
- Inputs:
- - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be float16.
- - **origin_output** (Tensor) - Data type same as `origin_input`.
- - **grad** (Tensor) - Data type same as `origin_input`.
-
- Outputs:
- Tensor, with data type same as `origin_input`.
-
- """
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID"):
- super(MaxPoolGradGrad, self).__init__(ksize, strides, padding)
-
- def infer_shape(self, x1_shape, x2_shape, grad_shape):
- return x1_shape
-
- def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
- args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
- return x1_dtype
-
-
- class MaximumGrad(Primitive):
- """Grad for maximum."""
-
- @prim_attr_register
- def __init__(self, grad_x=True, grad_y=True):
- """Initialize MaximumGrad"""
-
- def __call__(self, x, y, dout):
- raise NotImplementedError
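-
- # Illustrative note (an assumption, not from the original source): the backend kernel
- # routes `dout` element-wise to whichever of x and y held the maximum, returning the
- # pair (dx, dy); the grad_x/grad_y flags control which of the two gradients are produced.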
-
-
- class MaxPoolGradWithArgmax(_PoolGrad):
- """Computes the gradients of MaxPoolWithArgmax."""
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID"):
- self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
- super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)
-
- def infer_shape(self, x_shape, grad_shape, argmax_shape):
- if not grad_shape:
- raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.")
- return x_shape
-
- def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
- return grad_dtype
-
-
- class MaxPoolGradGradWithArgmax(_PoolGrad):
- r"""
- Computes the gradients of MaxPoolGradWithArgmax.
-
- Args:
- ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
- is an int number that represents height and width are both ksize, or a tuple
- of two int numbers that represent height and width respectively. Default: 1.
- strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
- the height and width of movement are both strides, or a tuple of two int numbers that
- represent height and width of movement respectively. Default: 1.
- padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
- Default: "valid".
-
- - same: Adopts the way of completion. The height and width of the output will be the same as
- the input. The total number of padding will be calculated in horizontal and vertical
- directions and evenly distributed to top and bottom, left and right if possible.
- Otherwise, the last extra padding will be done from the bottom and the right side.
-
- - valid: Adopts the way of discarding. The possible largest height and width of output
- will be returned without padding. Extra pixels will be discarded.
-
- Inputs:
- - **x** (Tensor) - Tensor with data format "NCHW", data type must be float16.
- - **grad** (Tensor) - Data type same as `x`.
- - **argmax** (Tensor) - Data type must be uint16 or int64.
-
- Outputs:
- Tensor, with data type same as `x`.
-
- """
-
- @prim_attr_register
- def __init__(self, ksize=1, strides=1, padding="VALID"):
- self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
- super(MaxPoolGradGradWithArgmax, self).__init__(ksize, strides, padding)
-
- def infer_shape(self, x_shape, grad_shape, argmax_shape):
- if not grad_shape:
- raise TypeError("The dout of MaxPoolGradGradWithArgmax should be a Tensor.")
- return x_shape
-
- def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
- args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
- return grad_dtype
-
-
- class MinimumGrad(Primitive):
- """Grad for minimum."""
-
- @prim_attr_register
- def __init__(self, grad_x=True, grad_y=True):
- """Initialize MinimumGrad"""
-
- def __call__(self, x, y, dout):
- raise NotImplementedError
-
-
- class L2NormalizeGrad(PrimitiveWithInfer):
- r"""
- Gradients of L2 normalize.
-
- Args:
- axis (int): The begin axis for the input to apply L2 normalize. Default: 0.
- epsilon (float): A small value added for numerical stability. Default: 1e-4.
-
- Inputs:
- - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
- - **out** (Tensor) - Must be the output of forward operator L2Normalize.
- - **dout** (Tensor) - The backprop of the next layer.
-
- Outputs:
- Tensor, gradients of L2Normalize `input_x`.
- """
-
- @prim_attr_register
- def __init__(self, axis=0, epsilon=1e-4):
- validator.check_value_type('axis', axis, [int], self.name)
- validator.check_value_type('epsilon', epsilon, [int, float], self.name)
-
- def infer_shape(self, input_x, out, dout):
- validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
- validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
- return input_x
-
- def infer_dtype(self, input_x, out, dout):
- args = {'input_x': input_x, 'out': out, 'dout': dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return input_x
-
-
- class LayerNormGrad(Primitive):
- """
- Applies the layer normalization to the input array.
-
- This operator will calculate the input gradients of layernorm.
-
- Args:
- begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
- begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.
-
- Returns:
- tuple[Tensor], tuple of 3 values (the gradients of layernorm input, gamma, beta).
- """
-
- @prim_attr_register
- def __init__(self, begin_norm_axis=1, begin_params_axis=1):
- """init"""
- self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
- self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
-
- def __call__(self, x, dy, variance, mean, gamma):
- raise NotImplementedError
-
-
- class LogSoftmaxGrad(PrimitiveWithInfer):
- """Computes gradient for the Log Softmax activation."""
-
- @prim_attr_register
- def __init__(self, axis=-1):
- """Initialize LogSoftmaxGrad"""
- validator.check_value_type("axis", axis, [int], self.name)
-
- def infer_shape(self, dout, logits):
- rank = len(logits)
- validator.check_int_range(self.axis, -rank, rank - 1, Rel.INC_BOTH, 'axis', self.name)
- return logits
-
- def infer_dtype(self, dout, logits):
- validator.check_subclass("logits", logits, mstype.tensor, self.name)
- return logits
-
-
- class LSTMGradData(PrimitiveWithInfer):
- """Computes the data gradients of LSTM."""
-
- @prim_attr_register
- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
- self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
- self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
- self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
- self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
- self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
- self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
-
- if bidirectional:
- self.num_directions = 2
- else:
- self.num_directions = 1
-
- def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
- hx_shape, cx_shape, reserve_shape, state_shape):
- # dhy and dcy should be same shape
- validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
- validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
- validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
- validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
- validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
-
- validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
- validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
-
- # dy: (seq_len, batch_size, hidden_size * num_directions)
- validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
- validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
- validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
-
- # (seq_len, batch_size, input_size)
- dx_shape = (y_shape[0], y_shape[1], self.input_size)
- dhx_shape = dhy_shape
- dcx_shape = dcy_shape
-
- return (dx_shape, dhx_shape, dcx_shape)
-
- def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
- hx_dtype, cx_dtype, reserve_dtype, state_dtype):
- args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
- return (dy_dtype, dy_dtype, dy_dtype)
-
-
- class LSTMGradWeight(PrimitiveWithInfer):
- """Computes the weight gradients of LSTM."""
-
- @prim_attr_register
- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
- self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
- self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
- self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
- self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
- self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
- self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
-
- if bidirectional:
- self.num_directions = 2
- else:
- self.num_directions = 1
-
- def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
- weight_size = 0
- gate_size = 4 * self.hidden_size
- for layer in range(self.num_layers):
- for _ in range(self.num_directions):
- input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
- weight_size += gate_size * input_layer_size
- weight_size += gate_size * self.hidden_size
- if self.has_bias:
- weight_size += 2 * gate_size
-
- return (weight_size, 1, 1)
-
- def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
- return hx_dtype
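-
- # Worked example (illustrative): for input_size=10, hidden_size=12, num_layers=1,
- # unidirectional and has_bias=True, gate_size = 48 and the flat weight size is
- # 48*10 + 48*12 + 2*48 = 1152, so infer_shape returns (1152, 1, 1).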
-
-
- class LSTMGrad(PrimitiveWithInfer):
- """Computes the data and weight gradients of LSTM."""
-
- @prim_attr_register
- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
- self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
- self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
- self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
- self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
- self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
- self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
- self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
-
- if bidirectional:
- self.num_directions = 2
- else:
- self.num_directions = 1
-
- def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
- dcy_shape, reserve_shape):
- # dhy and dcy should be same shape
- validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
- validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
- validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
- validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
- validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
-
- validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
- validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
-
- # dy: (seq_len, batch_size, hidden_size * num_directions)
- validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
- validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
- validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
-
- # (seq_len, batch_size, input_size)
- dx_shape = (y_shape[0], y_shape[1], self.input_size)
- dhx_shape = dhy_shape
- dcx_shape = dcy_shape
- weight_size = 0
- gate_size = 4 * self.hidden_size
- for layer in range(self.num_layers):
- for _ in range(self.num_directions):
- input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
- weight_size += gate_size * input_layer_size
- weight_size += gate_size * self.hidden_size
- if self.has_bias:
- weight_size += gate_size
-
- return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1))
-
- def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype,
- dcy_dtype, reserve_dtype):
- return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)
-
-
- class DynamicRNNGrad(PrimitiveWithInfer):
- """Computes the input gradients of DynamicRNN."""
-
- @prim_attr_register
- def __init__(self,
- cell_type='LSTM',
- direction='UNIDIRECTIONAL',
- cell_depth=1,
- use_peephole=False,
- keep_prob=1.0,
- cell_clip=-1.0,
- num_proj=0,
- time_major=True,
- forget_bias=0.0):
- self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
- self.add_prim_attr("io_format", "ND")
-
- def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
- c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
- validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
- num_step, batch_size, input_size = x_shape
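- # The LSTM weight packs the four gates (i, j, f, o) along its last axis,
- # so w_shape[-1] must be 4 * hidden_size.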
- hidden_size = w_shape[-1] // 4
- if w_shape[-1] % 4 != 0:
- raise ValueError(f"For {self.name}, w_shape[-1] should multiple of 4.")
- validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
- input_size + hidden_size, Rel.EQ, self.name)
- valid_shape = [num_step, batch_size, hidden_size]
- validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
- validator.check("y_shape", y_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("c_shape", c_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("i_shape", i_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("j_shape", j_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("f_shape", f_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("o_shape", o_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("tanhc_shape", tanhc_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("dh_shape", dh_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
- validator.check("dc_shape", dc_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
-
- return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape
-
- def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
- c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
- return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
-
-
- class DynamicGRUV2Grad(PrimitiveWithInfer):
- r"""
- Computes the input gradients of DynamicGRUV2.
-
- Args:
- direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
- Only 'UNIDIRECTIONAL' is currently supported.
- cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
- keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
- cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
- num_proj (int): An integer identifying the num proj in the op. Default: 0.
- time_major (bool): A bool identifying the time major in the op. Default: True.
- bias_type (str): A string identifying the bias type used in the op. Default: "double_bias".
- gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
- 'zrh' is another option.
- reset_after (bool): A bool identifying whether to apply the reset gate after matrix multiplication. Default: True.
-
- Inputs:
- - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
- The data type must be float16 or float32.
- - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
- The data type must be float16 or float32.
- - **weight_hidden** (Tensor) - Weight. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`.
- The data type must be float16 or float32.
- - **y** (Tensor) - A Tensor of shape
- :math:`(num_step, batch_size, min(hidden_size, num_proj))` if num_proj > 0,
- or :math:`(num_step, batch_size, hidden_size)` if num_proj == 0.
- The data type must be float16 or float32.
- - **init_h** (Tensor) - Hidden state of initial time.
- Tensor of shape :math:`(batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
- - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
- - **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
- The data type must be float16 or float32.
- - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
- Only `None` is currently supported.
- - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
-
- Outputs:
- - **dw_input** (Tensor) - A Tensor with the same shape as `weight_input`.
- Has the same data type as input `x`.
- - **dw_hidden** (Tensor) - A Tensor with the same shape as `weight_hidden`.
- Has the same data type as input `x`.
- - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
- Has the same data type as input `x`.
- - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
- Has the same data type as input `x`.
- - **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, input_size)`.
- Has the same data type as input `x`.
- - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`.
- Has the same data type as input `x`.
- """
-
- @prim_attr_register
- def __init__(self,
- direction='UNIDIRECTIONAL',
- cell_depth=1,
- keep_prob=1.0,
- cell_clip=-1.0,
- num_proj=0,
- time_major=True,
- bias_type="double_bias",
- gate_order="rzh",
- reset_after=True):
- self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
- self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
- self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
- self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
- self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
- self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
- self.bias_type = validator.check_string(bias_type,
- ['no_bias', 'single_bias', 'double_bias'], "bias_type", self.name)
- self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
- self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
- self.add_prim_attr("io_format", "ND")
-
- def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
- dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
- validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
- validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
- validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
- validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name)
- num_step, batch_size, input_size = x_shape
- hidden_size = whidden_shape[0]
- validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size",
- 3 * hidden_size, Rel.EQ, self.name)
- validator.check("weight_input_shape", winput_shape, "excepted shape",
- [input_size, 3 * hidden_size], Rel.EQ, self.name)
- if self.num_proj > 0:
- valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)]
- else:
- valid_y_shape = [num_step, batch_size, hidden_size]
- validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name)
-
- validator.check("init_h_shape", init_h_shape, "excepted shape",
- [batch_size, hidden_size], Rel.EQ, self.name)
- valid_shape = [num_step, batch_size, hidden_size]
- validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("dh_shape", dh_shape, "excepted shape",
- [batch_size, hidden_size], Rel.EQ, self.name)
- validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
- if seq_shape is not None:
- validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name)
-
- dx_shape = (num_step, batch_size, input_size)
- dh_shape = (batch_size, hidden_size)
- dwinput_shape = (input_size, 3 * hidden_size)
- dwhidden_shape = (hidden_size, 3 * hidden_size)
- db_shape = (3 * hidden_size,)
- return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape
-
- def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
- dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
- valid_types = (mstype.float16, mstype.float32)
- args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
- "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
- "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
- validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
- validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
- validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
- validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
- validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
- if seq_dtype is not None:
- validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
- if mask_dtype is not None:
- validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
- return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
-
-
- class PReLUGrad(PrimitiveWithInfer):
- r"""
- Gradients of PReLU operation.
-
- Note:
- 1-dimensional input_x is not supported.
-
- Inputs:
- - **y_backprop** (Tensor) - The gradient backpropagated from the next layer.
- - **input_x** (Tensor) - Must be the input `input_x` of forward operator PRelu.
- - **weight** (Tensor) - Float Tensor, w > 0, must be the input `weight` of forward operator PRelu.
-
- Outputs:
- Tuple of 2 Tensors, the gradients w.r.t. `input_x` and `weight` respectively.
- """
-
- @prim_attr_register
- def __init__(self):
- pass
-
- def infer_shape(self, y_backprop_shape, A_shape, w_shape):
- if len(A_shape) == 1:
- raise ValueError(f"For '{self.name}', 1-dimensional input_x is not supported.")
- return y_backprop_shape, w_shape
-
- def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
- tuple(map(partial(validator.check_tensor_dtype_valid,
- valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
- ('y_backprop', "input_x", "weight"),
- (y_backprop_dtype, A_dtype, w_dtype)))
- return y_backprop_dtype, w_dtype
-
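- # A hedged NumPy sketch (not part of this module) of the PReLU gradients,
- # assuming the forward op is y = x for x > 0 and y = w * x otherwise, with a
- # per-channel weight `w` broadcast along axis 1:
- #
- #     import numpy as np
- #     def prelu_grad_reference(y_backprop, x, w):
- #         wb = w.reshape((1, -1) + (1,) * (x.ndim - 2))  # broadcast over NCHW
- #         dx = np.where(x > 0, y_backprop, y_backprop * wb)
- #         dw = np.where(x > 0, 0.0, y_backprop * x).sum(
- #             axis=tuple(i for i in range(x.ndim) if i != 1))
- #         return dx, dw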
-
- class ReluGrad(Primitive):
- """Performs grad of Relu operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize ReluGrad"""
- self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
-
- def __call__(self, y_backprop, x):
- raise NotImplementedError
-
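- # A minimal NumPy sketch of the ReLU gradient this primitive stands for:
- # the upstream gradient passes through exactly where x > 0.
- #
- #     import numpy as np
- #     def relu_grad_reference(y_backprop, x):
- #         return y_backprop * (x > 0)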
-
- class ReLU6Grad(PrimitiveWithInfer):
- """Performs grad of ReLU6 operation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
-
- def __call__(self, y_grad, x):
- raise NotImplementedError
-
- def infer_shape(self, y_grad_shape, x_shape):
- return x_shape
-
- def infer_dtype(self, y_grad_dtype, x_dtype):
- valid_dtypes = (mstype.float16, mstype.float32)
- validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
- validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
- return x_dtype
-
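- # A minimal NumPy sketch: relu6(x) = min(max(x, 0), 6), so the gradient
- # passes through only on the interval 0 < x < 6.
- #
- #     import numpy as np
- #     def relu6_grad_reference(y_grad, x):
- #         return y_grad * ((x > 0) & (x < 6))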
-
- class ReluGradV2(PrimitiveWithInfer):
- """Performs grad of ReLUV2 operation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
-
- def __call__(self, gradients, mask):
- raise NotImplementedError
-
- def infer_shape(self, gradients_shape, mask_shape):
- return gradients_shape
-
- def infer_dtype(self, gradients_dtype, mask_dtype):
- validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name)
- validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name)
- return gradients_dtype
-
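- # A hedged sketch: ReluGradV2 gates the incoming gradients with the uint8
- # `mask` emitted by the forward ReLUV2. The element-wise view below ignores
- # any packed mask layout the backend kernel may actually use.
- #
- #     import numpy as np
- #     def relu_grad_v2_reference(gradients, mask):
- #         return gradients * (mask != 0)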
-
- class EluGrad(PrimitiveWithInfer):
- """Performs grad of Elu operation."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize EluGrad"""
-
- def infer_shape(self, y_grad_shape, x_shape):
- return x_shape
-
- def infer_dtype(self, y_grad_dtype, x_dtype):
- args = {'y_grad': y_grad_dtype, 'x': x_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
- return x_dtype
-
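- # A hedged NumPy sketch of the ELU gradient with alpha = 1: for forward input
- # x, d/dx elu(x) is 1 for x > 0 and exp(x) (= elu(x) + 1) otherwise. Whether
- # the kernel receives the forward input or the forward output is a backend
- # detail; the sketch assumes the forward input.
- #
- #     import numpy as np
- #     def elu_grad_reference(y_grad, x):
- #         return np.where(x > 0, y_grad, y_grad * np.exp(x))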
-
- class GatherDGrad(PrimitiveWithInfer):
- """Performs grad of GatherD operation."""
-
- @prim_attr_register
- def __init__(self, dim=0, shape=None):
- """Initialize GatherDGrad"""
- validator.check_is_int(dim, "dim", self.name)
- self.add_prim_attr("dim", dim)
- self.out_shape = shape
- self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
-
- def infer_shape(self, index_shape, grad_shape):
- return self.out_shape
-
- def infer_dtype(self, index_dtype, grad_dtype):
- return grad_dtype
-
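- # A hedged NumPy sketch of the scatter-add that inverts GatherD, written for
- # 2-D inputs and dim=0: the forward pass reads output[i, j] = x[index[i, j], j],
- # so each gradient element is accumulated back at the position it was read from.
- #
- #     import numpy as np
- #     def gather_d_grad_reference(index, grad, out_shape):
- #         dx = np.zeros(out_shape, dtype=grad.dtype)
- #         for i in range(index.shape[0]):
- #             for j in range(index.shape[1]):
- #                 dx[index[i, j], j] += grad[i, j]
- #         return dx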
-
- class ResizeBilinearGrad(PrimitiveWithInfer):
- """Performs grad of ResizeBilinear operation."""
-
- @prim_attr_register
- def __init__(self, align_corners=False):
- """init"""
-
- def infer_shape(self, dout_shape, orig_shape):
- return orig_shape
-
- def infer_dtype(self, dout_dtype, orig_dtype):
- return orig_dtype
-
-
- class ResizeNearestNeighborGrad(PrimitiveWithInfer):
- """
- Compute gradient of `ResizeNearestNeighbor` operator.
-
- Note:
- The shape of input parameter `size` must be (height, width).
-
- Args:
- align_corners (bool): Whether the centers of the 4 corner pixels of the input
- and output tensors are aligned. Default: False.
- """
-
- @prim_attr_register
- def __init__(self, align_corners=False):
- """Initialize ResizeNearestNeighborGrad"""
- self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
-
- def __infer__(self, grads, size):
- shp = (grads['shape'][0], grads['shape'][1]) + size['value']
- return {'shape': shp,
- 'dtype': grads['dtype'],
- 'value': None}
-
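- # Example of the shape rule in __infer__ above: grads of shape (1, 3, 4, 4)
- # with size=(8, 8) yields y of shape (1, 3, 8, 8): batch and channel come
- # from `grads`, the spatial size from `size`.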
-
- class ROIAlignGrad(PrimitiveWithInfer):
- """
- ROIAlignGrad operator.
-
- Args:
- xdiff_shape (tuple): The shape of the gradient w.r.t. the forward input
- `features`, i.e. the shape of this operator's output.
- pooled_height (int): The output feature height.
- pooled_width (int): The output feature width.
- spatial_scale (float): The scale factor mapping image coordinates to
- feature-map coordinates, typically 1 / feature_stride.
- sample_num (int): Number of sampling points. Default: 2.
- """
-
- @prim_attr_register
- def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
- """Initialize ROIAlignGrad"""
- validator.check_value_type("pooled_height", pooled_height, [int], self.name)
- validator.check_value_type("pooled_width", pooled_width, [int], self.name)
- validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
- validator.check_value_type("sample_num", sample_num, [int], self.name)
- validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
- self.xdiff_shape = xdiff_shape
- self.pooled_height = pooled_height
- self.pooled_width = pooled_width
- self.spatial_scale = spatial_scale
- self.sample_num = sample_num
-
- def infer_shape(self, ydiff_shape, rois_shape):
- return self.xdiff_shape
-
- def infer_dtype(self, ydiff_type, rois_type):
- return ydiff_type
-
-
- class SigmoidGrad(PrimitiveWithInfer):
- """Gets the gradient of Sigmoid operation."""
-
- @prim_attr_register
- def __init__(self):
- pass
-
- def infer_shape(self, out, dout):
- return out
-
- def infer_dtype(self, out, dout):
- args = {'out': out, 'dout': dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return out
-
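- # A minimal NumPy sketch: SigmoidGrad receives the forward output `out` and
- # the upstream gradient `dout`; d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
- #
- #     import numpy as np
- #     def sigmoid_grad_reference(out, dout):
- #         return dout * out * (1.0 - out)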
-
- class HSigmoidGrad(PrimitiveWithInfer):
- """Gets the gradient of HSigmoid operation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
-
- def infer_shape(self, y_grad_shape, x_shape):
- return x_shape
-
- def infer_dtype(self, y_grad_dtype, x_dtype):
- valid_dtypes = (mstype.float16, mstype.float32)
- validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
- validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
- return x_dtype
-
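- # A hedged NumPy sketch: hsigmoid(x) = relu6(x + 3) / 6, so its derivative is
- # 1/6 on -3 < x < 3 and 0 elsewhere.
- #
- #     import numpy as np
- #     def hsigmoid_grad_reference(y_grad, x):
- #         return np.where((x > -3) & (x < 3), y_grad / 6.0, 0.0)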
-
- class HSwishGrad(PrimitiveWithInfer):
- """Gets the gradient of HSwish operation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
-
- def infer_shape(self, y_grad_shape, x_shape):
- return x_shape
-
- def infer_dtype(self, y_grad_dtype, x_dtype):
- valid_dtypes = (mstype.float16, mstype.float32)
- validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
- validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
- return x_dtype
-
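- # A hedged NumPy sketch: hswish(x) = x * relu6(x + 3) / 6, so its derivative
- # is 0 for x <= -3, 1 for x >= 3, and x / 3 + 0.5 in between.
- #
- #     import numpy as np
- #     def hswish_grad_reference(y_grad, x):
- #         d = np.where(x <= -3, 0.0, np.where(x >= 3, 1.0, x / 3.0 + 0.5))
- #         return y_grad * d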
-
- class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
- """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize SigmoidCrossEntropyWithLogitsGrad"""
- self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
-
- def infer_shape(self, x_shape, y_shape, dout_shape):
- validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
- validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
- args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return dout_dtype
-
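- # A minimal NumPy sketch: for loss = max(x, 0) - x * y + log(1 + exp(-|x|)),
- # the gradient w.r.t. the logits is (sigmoid(x) - y), scaled by `dout`.
- #
- #     import numpy as np
- #     def sigmoid_ce_with_logits_grad_reference(x, y, dout):
- #         return (1.0 / (1.0 + np.exp(-x)) - y) * dout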
-
- class SliceGrad(PrimitiveWithInfer):
- """Reverse of slice."""
-
- @prim_attr_register
- def __init__(self):
- """Initialize SliceGrad"""
- self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
-
- def __infer__(self, dy, x, begin, size):
- dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
- dy_shape_len = len(dy_shape)
- for i in range(dy_shape_len):
- validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
- validator.check(f'dy_shape[{i}]', dy_shape[i], f'size[{i}]', size_value[i], Rel.EQ, self.name)
- return {'shape': x_shape,
- 'dtype': x['dtype'],
- 'value': None}
-
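- # A hedged NumPy sketch: the upstream gradient `dy` is written back into a
- # zero tensor of the original shape at offset `begin`.
- #
- #     import numpy as np
- #     def slice_grad_reference(dy, x_shape, begin, size):
- #         dx = np.zeros(x_shape, dtype=dy.dtype)
- #         dx[tuple(slice(b, b + s) for b, s in zip(begin, size))] = dy
- #         return dx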
-
- class SmoothL1LossGrad(PrimitiveWithInfer):
- """Computes gradient for prediction on SmoothL1Loss."""
-
- @prim_attr_register
- def __init__(self, beta=1.0):
- self.add_prim_attr('sigma', beta)
-
- def infer_shape(self, prediction, target, dloss):
- validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
- validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
- return prediction
-
- def infer_dtype(self, prediction, target, dloss):
- args = {"prediction": prediction, "target": target, 'dloss': dloss}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return dloss
-
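- # A hedged NumPy sketch of the gradient w.r.t. `prediction`: with
- # diff = prediction - target, smooth L1 differentiates to diff / beta where
- # |diff| < beta and to sign(diff) elsewhere, scaled by `dloss`.
- #
- #     import numpy as np
- #     def smooth_l1_loss_grad_reference(prediction, target, dloss, beta=1.0):
- #         diff = prediction - target
- #         return np.where(np.abs(diff) < beta, diff / beta, np.sign(diff)) * dloss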
-
- class StridedSliceGrad(PrimitiveWithInfer):
- """
- Performs grad of StridedSlice operation.
-
- Args:
- begin_mask (int): An int32 bitmask; if bit i is set, `begin[i]` is ignored
- and the slice starts from the beginning of dimension i. Default: 0.
- end_mask (int): An int32 bitmask; if bit i is set, `end[i]` is ignored and
- the slice runs to the end of dimension i. Default: 0.
- ellipsis_mask (int): An int32 bitmask whose set bit marks an ellipsis that
- spans the remaining dimensions. Default: 0.
- new_axis_mask (int): An int32 bitmask whose set bits insert new dimensions
- of size 1. Default: 0.
- shrink_axis_mask (int): An int32 bitmask whose set bits remove the
- corresponding dimensions. Default: 0.
-
- Returns:
- Tensor, with the same shape as the input of the forward StridedSlice operation.
- """
-
- @prim_attr_register
- def __init__(self,
- begin_mask=0,
- end_mask=0,
- ellipsis_mask=0,
- new_axis_mask=0,
- shrink_axis_mask=0):
- """Initialize StrideSliceGrad"""
- validator.check_value_type('begin_mask', begin_mask, [int], self.name)
- validator.check_value_type('end_mask', end_mask, [int], self.name)
- validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
- validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
- validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
- self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
-
- def __infer__(self, dy, shapex, begin, end, strides):
- validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)
-
- for idx, item in enumerate(shapex['value']):
- validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
- for idx, item in enumerate(begin['value']):
- validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
- for idx, item in enumerate(end['value']):
- validator.check_value_type("end[%d]" % idx, item, [int], self.name)
- for idx, item in enumerate(strides['value']):
- validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
-
- return {'shape': shapex['value'],
- 'dtype': dy['dtype'],
- 'value': None}
-
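- # A hedged NumPy sketch for the all-masks-zero case: `dy` is scattered back
- # into a zero tensor of shape `shapex` at the strided positions.
- #
- #     import numpy as np
- #     def strided_slice_grad_reference(dy, shapex, begin, end, strides):
- #         dx = np.zeros(shapex, dtype=dy.dtype)
- #         dx[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))] = dy
- #         return dx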
-
- class SoftplusGrad(PrimitiveWithInfer):
- """Computes gradient for the Log Softmax activation."""
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])
-
- def infer_shape(self, dout_shape, x_shape):
- validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
- return x_shape
-
- def infer_dtype(self, dout_dtype, x_dtype):
- args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
- return x_dtype
-
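- # A minimal NumPy sketch: softplus(x) = log(1 + exp(x)) differentiates to
- # sigmoid(x), so the gradient is dout * sigmoid(x).
- #
- #     import numpy as np
- #     def softplus_grad_reference(dout, x):
- #         return dout / (1.0 + np.exp(-x))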
-
- class TanhGrad(PrimitiveWithInfer):
- """Computes gradient of hyperbolic tangent of input element-wise."""
-
- @prim_attr_register
- def __init__(self):
- pass
-
- def infer_shape(self, out, dout):
- return out
-
- def infer_dtype(self, out, dout):
- args = {"out": out, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return out
-
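- # A minimal NumPy sketch: TanhGrad receives the forward output `out` and the
- # upstream gradient `dout`; d/dx tanh(x) = 1 - tanh(x)^2.
- #
- #     import numpy as np
- #     def tanh_grad_reference(out, dout):
- #         return dout * (1.0 - out ** 2)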
-
- class MirrorPadGrad(PrimitiveWithInfer):
- """Gradients of MirrorPad operation."""
-
- @prim_attr_register
- def __init__(self, mode="REFLECT"):
- """Initialize MirrorPad"""
- validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
- self.mode = mode
-
- def __infer__(self, dout, paddings):
- validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
- validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
- validator.check("paddings rank", len(paddings['shape']), "expected", 2, Rel.EQ, self.name)
- validator.check("paddings dim_1", paddings['shape'][1], "expected", 2, Rel.EQ, self.name)
-
- if paddings['value'] is None:
- raise ValueError(f"For {self.name}, paddings must be const.")
- paddings_value = paddings['value'].asnumpy()
- y_shape = ()
- dout_shape = dout['shape']
- for i, val in enumerate(dout_shape):
- y_shape += (val - paddings_value[i][0] - paddings_value[i][1],)
- return {'shape': y_shape,
- 'dtype': dout['dtype'],
- 'value': None}
-
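- # Shape rule from __infer__ above: y_shape[i] = dout_shape[i]
- # - paddings[i][0] - paddings[i][1]. Conceptually, the gradient of MirrorPad
- # accumulates the padded border regions back onto the interior elements they
- # mirrored.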
-
- class EmbeddingLookupCommGrad(PrimitiveWithInfer):
- """
- Performs the gradient for the communication part of EmbeddingLookup operator.
-
- This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
- this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
- """
-
- @prim_attr_register
- def __init__(self):
- self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
- self.add_prim_attr('primitive_target', 'CPU')
-
- def __infer__(self, dy, split_num):
- """
- This primitive is implemented by three steps:
- 1) Splits the 'dy' along dimension 0 into 'split_num' parts.
- 2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
- 3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
- along dimension 0.
-
- The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
- """
- dy_shape = tuple(dy['shape'])
- split_num_value = split_num['value']
- validator.check_value_type("split_num_value", split_num_value, [int], self.name)
- dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
- return {'shape': dy_shape_all,
- 'dtype': dy['dtype'],
- 'value': None}
-
-
- class RefToEmbed(Primitive):
- r"""
- Make a key from Ref.
-
- The key is a symbolic_key, an embedding of a Parameter. It is used as the key of the
- variable in env_type, and items are retrieved by the operation `env_get_item` with the
- symbolic_key instance. The `Parameter` is a ref.
-
- Inputs:
- - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
-
- Outputs:
- symbolic_key, made from the Ref.
-
- Examples:
- >>> class Net(nn.Cell):
- >>> def __init__(self):
- >>> super(Net, self).__init__()
- >>> self.weight = mindspore.Parameter(1.0, name='weight')
- >>>
- >>> def construct(self):
- >>> key = RefToEmbed()(self.weight)
- >>> return key, self.weight
- """
- __mindspore_signature__ = (
- sig.make_sig('variable', sig.sig_rw.RW_REF),
- )
-
- @prim_attr_register
- def __init__(self):
- pass
-
-
- class AtanGrad(PrimitiveWithInfer):
- """
- Computes AtanGrad of input element-wise.
-
- Returns:
- Tensor, has the same type as input.
- """
-
- @prim_attr_register
- def __init__(self):
- """Initialize AtanGrad"""
-
- def infer_shape(self, x, dout):
- validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, dout):
- args = {"x": x, "dout": dout}
- validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- return x
-
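- # A minimal NumPy sketch: d/dx atan(x) = 1 / (1 + x^2).
- #
- #     import numpy as np
- #     def atan_grad_reference(x, dout):
- #         return dout / (1.0 + x ** 2)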
-
- class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
- """Computes the state gradients of BasicLSTMCell."""
-
- @prim_attr_register
- def __init__(self, forget_bias, activation):
- self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
- self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
- self.add_prim_attr("io_format", "ND")
-
- def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
- # dht and dct must have the same shape as c
- validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
- validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
- validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
- validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
-
- dgate_shape = (c_shape[0], 4 * c_shape[1])
- dct_1_shape = c_shape
-
- return (dgate_shape, dct_1_shape)
-
- def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
- validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
- validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
- validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
- validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
- validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
- validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
- validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
- validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
- validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
- return (c_dtype, c_dtype)
-
-
- class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
- """Computes the weight gradients of BasicLSTM."""
-
- @prim_attr_register
- def __init__(self):
- self.add_prim_attr("io_format", "HWCN")
-
- def infer_shape(self, x_shape, h_shape, dgate_shape):
- validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
- validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
- validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
- validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
- validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
- validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
- input_size = x_shape[1]
- hidden_size = h_shape[1]
- dw_shape = (input_size + hidden_size, 4 * hidden_size)
- db_shape = (4 * hidden_size,)
- return (dw_shape, db_shape)
-
- def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
- validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
- validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
- validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
- validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
- return (x_dtype, x_dtype)
-
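- # A hedged NumPy sketch consistent with the shapes inferred above, assuming
- # the forward gates are computed as concat([x, h], axis=1) @ w + b:
- #
- #     import numpy as np
- #     def basic_lstm_cell_weight_grad_reference(x, h, dgate):
- #         xh = np.concatenate([x, h], axis=1)  # (batch, input_size + hidden_size)
- #         dw = xh.T @ dgate                    # (input_size + hidden_size, 4 * hidden_size)
- #         db = dgate.sum(axis=0)               # (4 * hidden_size,)
- #         return dw, db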
-
- class BasicLSTMCellInputGrad(PrimitiveWithInfer):
- """Computes the input gradients of BasicLSTM."""
-
- @prim_attr_register
- def __init__(self, keep_prob):
- self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
- self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
- self.add_prim_attr("io_format", "ND")
-
- def infer_shape(self, dgate_shape, w_shape):
- validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
- validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
- validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
- batch_size = dgate_shape[0]
- hidden_size = dgate_shape[1] // 4
- input_size = w_shape[0] - hidden_size
- dxt_shape = (batch_size, input_size)
- dht_shape = (batch_size, hidden_size)
- return (dxt_shape, dht_shape)
-
- def infer_dtype(self, dgate_dtype, w_dtype):
- validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
- validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
- validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
- validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
- return (dgate_dtype, dgate_dtype)
-
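- # A hedged NumPy sketch consistent with the shapes inferred above: the gate
- # gradient is propagated back through `w` and split into input and hidden
- # parts (the keep_prob dropout mask is not modeled here).
- #
- #     import numpy as np
- #     def basic_lstm_cell_input_grad_reference(dgate, w):
- #         dxh = dgate @ w.T                    # (batch, input_size + hidden_size)
- #         hidden_size = dgate.shape[1] // 4
- #         return dxh[:, :-hidden_size], dxh[:, -hidden_size:]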
-
- class InvGrad(PrimitiveWithInfer):
- """Computes gradients for inv operation."""
-
- @prim_attr_register
- def __init__(self):
- pass
-
- def infer_shape(self, x, grad):
- validator.check("x_shape", x, "grad_shape", grad, Rel.EQ, self.name)
- return x
-
- def infer_dtype(self, x, grad):
- validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
- validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
- return x
-
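- # A hedged sketch, assuming the first input is the forward output y = 1 / x,
- # as is common for Inv-style bprops: d/dx (1/x) = -1/x^2 = -y^2.
- #
- #     import numpy as np
- #     def inv_grad_reference(y, grad):
- #         return -y * y * grad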
-
- class LRNGrad(PrimitiveWithInfer):
- """Computes gradients for LRN operation."""
-
- @prim_attr_register
- def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
- self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
- validator.check_value_type("depth_radius", depth_radius, [int], self.name)
- validator.check_value_type("bias", bias, [float], self.name)
- validator.check_value_type("alpha", alpha, [float], self.name)
- validator.check_value_type("beta", beta, [float], self.name)
-
- def infer_dtype(self, grads, x, y):
- args = {"grads": grads, "x": x, "y": y}
- validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
- return x
-
- def infer_shape(self, grads, x, y):
- return x
|