| @@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl | |||||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder, | GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder, | ||||
| LogSoftmax, | LogSoftmax, | ||||
| MaxPool, DataFormatDimMap, | MaxPool, DataFormatDimMap, | ||||
| AvgPool, Conv2DBackpropInput, ConfusionMulGrad, | |||||
| AvgPool, Conv2DBackpropInput, | |||||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | ||||
| ResizeBilinear, Sigmoid, | ResizeBilinear, Sigmoid, | ||||
| SigmoidCrossEntropyWithLogits, | SigmoidCrossEntropyWithLogits, | ||||
| @@ -136,7 +136,6 @@ __all__ = [ | |||||
| 'LogSoftmax', | 'LogSoftmax', | ||||
| 'SoftmaxCrossEntropyWithLogits', | 'SoftmaxCrossEntropyWithLogits', | ||||
| 'ROIAlign', | 'ROIAlign', | ||||
| 'ConfusionMulGrad', | |||||
| 'SparseSoftmaxCrossEntropyWithLogits', | 'SparseSoftmaxCrossEntropyWithLogits', | ||||
| 'SGD', | 'SGD', | ||||
| 'ApplyMomentum', | 'ApplyMomentum', | ||||
| @@ -19,6 +19,7 @@ from ..._checkparam import Rel | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | from ..primitive import PrimitiveWithInfer, prim_attr_register | ||||
| from ..operations.math_ops import _infer_shape_reduce | |||||
| class StridedSliceAICPU(PrimitiveWithInfer): | class StridedSliceAICPU(PrimitiveWithInfer): | ||||
| @@ -681,3 +682,63 @@ class DynamicRNN(PrimitiveWithInfer): | |||||
| validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name) | validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name) | ||||
| validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name) | validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name) | ||||
| return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | ||||
| class ConfusionMulGrad(PrimitiveWithInfer): | |||||
| """ | |||||
| `output0` is the dot product result of input0 and input1. | |||||
| `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it. | |||||
| Args: | |||||
| axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. | |||||
| Default:(), reduce all dimensions. Only constant value is allowed. | |||||
| keep_dims (bool): | |||||
| - If true, keep these reduced dimensions and the length as 1. | |||||
| - If false, don't keep these dimensions. Default:False. | |||||
| Inputs: | |||||
| - **input_0** (Tensor) - The input Tensor. | |||||
| - **input_1** (Tensor) - The input Tensor. | |||||
| - **input_2** (Tensor) - The input Tensor. | |||||
| Outputs: | |||||
| - **output_0** (Tensor) - The same shape as `input0`. | |||||
| - **output_1** (Tensor) | |||||
| - If axis is (), and keep_dims is false, the output is a 0-D array representing | |||||
| the sum of all elements in the input array. | |||||
| - If axis is int, set as 2, and keep_dims is false, | |||||
| the shape of output is :math:`(x_1,x_3,...,x_R)`. | |||||
| - If axis is tuple(int), set as (2,3), and keep_dims is false, | |||||
| the shape of output is :math:`(x_1,x_4,...x_R)`. | |||||
| Examples: | |||||
| >>> confusion_mul_grad = P.ConfusionMulGrad() | |||||
| >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32) | |||||
| >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32) | |||||
| >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32) | |||||
| >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2) | |||||
| output_0: | |||||
| [[ 3. 1. 0.] | |||||
| [-6. 2. -2.]] | |||||
| output_1: | |||||
| -3.0 | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, axis=(), keep_dims=False): | |||||
| self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"]) | |||||
| self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name) | |||||
| self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name) | |||||
| def infer_shape(self, input0_shape, input1_shape, input2_shape): | |||||
| outshape0 = input0_shape | |||||
| outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name) | |||||
| return outshape0, outshape1 | |||||
| def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype): | |||||
| validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name) | |||||
| return input0_dtype, input1_dtype | |||||
| @@ -27,7 +27,6 @@ from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register | ||||
| from ..operations.math_ops import _infer_shape_reduce | |||||
| def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): | def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): | ||||
| @@ -5151,66 +5150,6 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||||
| return var_dtype, accum_dtype, linear_dtype | return var_dtype, accum_dtype, linear_dtype | ||||
| class ConfusionMulGrad(PrimitiveWithInfer): | |||||
| """ | |||||
| `output0` is the dot product result of input0 and input1. | |||||
| `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it. | |||||
| Args: | |||||
| axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. | |||||
| Default:(), reduce all dimensions. Only constant value is allowed. | |||||
| keep_dims (bool): | |||||
| - If true, keep these reduced dimensions and the length as 1. | |||||
| - If false, don't keep these dimensions. Default: False. | |||||
| Inputs: | |||||
| - **input_0** (Tensor) - The input Tensor. | |||||
| - **input_1** (Tensor) - The input Tensor. | |||||
| - **input_2** (Tensor) - The input Tensor. | |||||
| Outputs: | |||||
| - **output_0** (Tensor) - The same shape as `input0`. | |||||
| - **output_1** (Tensor) | |||||
| - If axis is (), and keep_dims is False, the output is a 0-D array representing | |||||
| the sum of all elements in the input array. | |||||
| - If axis is int, set as 2, and keep_dims is False, | |||||
| the shape of output is :math:`(x_1,x_3,...,x_R)`. | |||||
| - If axis is tuple(int), set as (2,3), and keep_dims is False, | |||||
| the shape of output is :math:`(x_1,x_4,...x_R)`. | |||||
| Examples: | |||||
| >>> confusion_mul_grad = P.ConfusionMulGrad() | |||||
| >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32) | |||||
| >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32) | |||||
| >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32) | |||||
| >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2) | |||||
| output_0: | |||||
| [[ 3. 1. 0.] | |||||
| [-6. 2. -2.]] | |||||
| output_1: | |||||
| -3.0 | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, axis=(), keep_dims=False): | |||||
| self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"]) | |||||
| self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name) | |||||
| self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name) | |||||
| def infer_shape(self, input0_shape, input1_shape, input2_shape): | |||||
| outshape0 = input0_shape | |||||
| outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name) | |||||
| return outshape0, outshape1 | |||||
| def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype): | |||||
| validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name) | |||||
| return input0_dtype, input1_dtype | |||||
| class Dropout(PrimitiveWithInfer): | class Dropout(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| During training, randomly zeroes some of the elements of the input tensor with probability. | During training, randomly zeroes some of the elements of the input tensor with probability. | ||||
| @@ -2329,21 +2329,6 @@ test_case_other_ops = [ | |||||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | ||||
| Tensor(np.array([1.2]).astype(np.float32))], | Tensor(np.array([1.2]).astype(np.float32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('ConfusionMulGrad_1', { | |||||
| 'block': P.ConfusionMulGrad(axis=[0], keep_dims=False), | |||||
| 'desc_inputs': [[3, 2], [3, 2], [3, 2]], | |||||
| 'desc_bprop': [[3, 2], [2]], | |||||
| 'skip': ['backward']}), | |||||
| ('ConfusionMulGrad_2', { | |||||
| 'block': P.ConfusionMulGrad(axis=[0], keep_dims=True), | |||||
| 'desc_inputs': [[3, 2], [3, 2], [3, 2]], | |||||
| 'desc_bprop': [[3, 2], [1, 2]], | |||||
| 'skip': ['backward']}), | |||||
| ('ConfusionMulGrad_3', { | |||||
| 'block': P.ConfusionMulGrad(axis=(), keep_dims=True), | |||||
| 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], | |||||
| 'desc_bprop': [[2, 3, 4], [1, 1, 1]], | |||||
| 'skip': ['backward']}), | |||||
| ('HistogramSummary', { | ('HistogramSummary', { | ||||
| 'block': HistogramSummaryNet(), | 'block': HistogramSummaryNet(), | ||||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | ||||