Merge pull request !588 from liuxiao/add-reluv2tags/v0.2.0-alpha
| @@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"re_lu6", "relu6"}, | {"re_lu6", "relu6"}, | ||||
| {"re_lu6_grad", "relu6_grad"}, | {"re_lu6_grad", "relu6_grad"}, | ||||
| {"re_lu", "relu"}, | {"re_lu", "relu"}, | ||||
| {"re_luv2", "relu_v2"}, | |||||
| {"tensor_add", "add"}, | {"tensor_add", "add"}, | ||||
| {"reduce_mean", "reduce_mean_d"}, | {"reduce_mean", "reduce_mean_d"}, | ||||
| {"reduce_max", "reduce_max_d"}, | {"reduce_max", "reduce_max_d"}, | ||||
| @@ -227,6 +227,18 @@ def get_bprop_relu6(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.ReLUV2) | |||||
| def get_bprop_relu_v2(self): | |||||
| """Grad definition for `ReLUV2` operation.""" | |||||
| input_grad = G.ReluGradV2() | |||||
| def bprop(x, out, dout): | |||||
| mask = out[1] | |||||
| dx = input_grad(dout[0], mask) | |||||
| return (dx,) | |||||
| return bprop | |||||
| @bprop_getters.register(P.HSwish) | @bprop_getters.register(P.HSwish) | ||||
| def get_bprop_hswish(self): | def get_bprop_hswish(self): | ||||
| """Grad definition for `HSwish` operation.""" | """Grad definition for `HSwish` operation.""" | ||||
| @@ -33,6 +33,7 @@ from .cast import _cast_tbe | |||||
| from .conv2d import _conv2d_tbe | from .conv2d import _conv2d_tbe | ||||
| from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe | from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe | ||||
| from .conv2d_backprop_input import _conv2d_backprop_input_tbe | from .conv2d_backprop_input import _conv2d_backprop_input_tbe | ||||
| from .confusion_mul_grad import _confusion_mul_grad_tbe | |||||
| from .dropout_do_mask import _dropout_do_mask_tbe | from .dropout_do_mask import _dropout_do_mask_tbe | ||||
| from .gelu import _gelu_tbe | from .gelu import _gelu_tbe | ||||
| from .gelu_grad import _gelu_grad_tbe | from .gelu_grad import _gelu_grad_tbe | ||||
| @@ -46,6 +47,8 @@ from .relu import _relu_tbe | |||||
| from .relu_grad import _relu_grad_tbe | from .relu_grad import _relu_grad_tbe | ||||
| from .relu6 import _relu6_tbe | from .relu6 import _relu6_tbe | ||||
| from .relu6_grad import _relu6_grad_tbe | from .relu6_grad import _relu6_grad_tbe | ||||
| from .relu_v2 import _relu_v2_tbe | |||||
| from .relu_grad_v2 import _relu_grad_v2_tbe | |||||
| from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe | from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe | ||||
| from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe | from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe | ||||
| from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe | from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ConfusionMulGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| confusion_mul_grad_op_info = TBERegOp("ConfusionMulGrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .attr("axis", "required", "listInt", "all") \ | |||||
| .attr("keep_dims", "required", "bool", "all") \ | |||||
| .input(0, "input0", False, "required", "all") \ | |||||
| .input(1, "input1", False, "required", "all") \ | |||||
| .input(2, "input2", False, "required", "all") \ | |||||
| .output(0, "output0", False, "required", "all") \ | |||||
| .output(1, "output1", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(confusion_mul_grad_op_info) | |||||
| def _confusion_mul_grad_tbe(): | |||||
| """ConfusionMulGrad TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ReluGradV2 op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| relu_grad_v2_op_info = TBERegOp("ReluGradV2") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("relu_grad_v2.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("relu_grad_v2") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "gradients", False, "required", "all") \ | |||||
| .input(1, "mask", False, "rerequired", "all") \ | |||||
| .output(0, "backprops", True, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.U8_Default, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.U8_Default, DataType.I8_5HD) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_Default, DataType.U8_5HD) \ | |||||
| .get_op_info() | |||||
| @op_info_register(relu_grad_v2_op_info) | |||||
| def _relu_grad_v2_tbe(): | |||||
| """ReluGradV2 TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ReluV2 op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| relu_v2_op_info = TBERegOp("ReLUV2") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("relu_v2.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("relu_v2") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .output(1, "mask", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(relu_v2_op_info) | |||||
| def _relu_v2_tbe(): | |||||
| """ReluV2 TBE register""" | |||||
| return | |||||
| @@ -58,8 +58,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||||
| GetNext, L2Normalize, LayerNorm, L2Loss, | GetNext, L2Normalize, LayerNorm, L2Loss, | ||||
| LogSoftmax, | LogSoftmax, | ||||
| MaxPool, ExtractImagePatches, | MaxPool, ExtractImagePatches, | ||||
| AvgPool, Conv2DBackpropInput, | |||||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, | |||||
| AvgPool, Conv2DBackpropInput, ConfusionMulGrad, | |||||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||||
| ResizeBilinear, Sigmoid, | ResizeBilinear, Sigmoid, | ||||
| SigmoidCrossEntropyWithLogits, | SigmoidCrossEntropyWithLogits, | ||||
| SmoothL1Loss, Softmax, | SmoothL1Loss, Softmax, | ||||
| @@ -101,6 +101,7 @@ __all__ = [ | |||||
| 'LogSoftmax', | 'LogSoftmax', | ||||
| 'SoftmaxCrossEntropyWithLogits', | 'SoftmaxCrossEntropyWithLogits', | ||||
| 'ROIAlign', | 'ROIAlign', | ||||
| 'ConfusionMulGrad', | |||||
| 'SparseSoftmaxCrossEntropyWithLogits', | 'SparseSoftmaxCrossEntropyWithLogits', | ||||
| 'SGD', | 'SGD', | ||||
| 'ApplyMomentum', | 'ApplyMomentum', | ||||
| @@ -138,6 +139,7 @@ __all__ = [ | |||||
| 'Split', | 'Split', | ||||
| 'ReLU', | 'ReLU', | ||||
| 'ReLU6', | 'ReLU6', | ||||
| 'ReLUV2', | |||||
| 'Elu', | 'Elu', | ||||
| 'Erf', | 'Erf', | ||||
| 'Sigmoid', | 'Sigmoid', | ||||
| @@ -730,6 +730,27 @@ class ReLU6Grad(PrimitiveWithInfer): | |||||
| return x_dtype | return x_dtype | ||||
| class ReluGradV2(PrimitiveWithInfer): | |||||
| """Performs grad of ReLUV2 operation.""" | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output']) | |||||
| def __call__(self, gradients, mask): | |||||
| raise NotImplementedError | |||||
| def infer_shape(self, gradients_shape, mask_shape): | |||||
| return gradients_shape | |||||
| def infer_dtype(self, gradients_dtype, mask_dtype): | |||||
| args_type = {'gradients': gradients_dtype, 'mask': mask_dtype} | |||||
| validator.check_args_tensor(args_type) | |||||
| validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type) | |||||
| validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,)) | |||||
| return gradients_dtype | |||||
| class EluGrad(PrimitiveWithInfer): | class EluGrad(PrimitiveWithInfer): | ||||
| """Performs grad of Elu operation.""" | """Performs grad of Elu operation.""" | ||||
| @@ -1329,7 +1329,7 @@ class Concat(PrimitiveWithInfer): | |||||
| def _get_pack_shape(x_shape, x_type, axis): | def _get_pack_shape(x_shape, x_type, axis): | ||||
| """for pack output shape""" | """for pack output shape""" | ||||
| validator.check_type("shape", x_shape, [tuple]) | |||||
| validator.check_type("shape", x_shape, [tuple, list]) | |||||
| validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) | validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) | ||||
| validator.check_subclass("shape0", x_type[0], mstype.tensor) | validator.check_subclass("shape0", x_type[0], mstype.tensor) | ||||
| validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) | validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) | ||||
| @@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from ..operations.math_ops import _infer_shape_reduce | |||||
| def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): | def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): | ||||
| @@ -233,6 +234,62 @@ class ReLU6(PrimitiveWithInfer): | |||||
| return input_x | return input_x | ||||
| class ReLUV2(PrimitiveWithInfer): | |||||
| r""" | |||||
| Computes ReLU(Rectified Linear Unit) of input tensor element-wise. | |||||
| It returns :math:`\max(x,\ 0)` element-wise. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The input tensor should be a 4-D tensor. | |||||
| Outputs: | |||||
| - **output** (Tensor) - Has the same type and shape as the `input_x`. | |||||
| - **mask** (Tensor) - A tensor whose data type must be uint8. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32) | |||||
| >>> relu_v2 = P.ReLUV2() | |||||
| >>> output = relu_v2(input_x) | |||||
| ([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]], | |||||
| [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]]) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init ReLUV2""" | |||||
| self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask']) | |||||
| def __infer__(self, input_x): | |||||
| input_shape = list(input_x['shape']) | |||||
| input_dtype = input_x['dtype'] | |||||
| mask_shape = [] | |||||
| if len(input_shape) != 4: | |||||
| raise ValueError("The `input_x` should be a 4-D tensor, " | |||||
| f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}") | |||||
| for i in enumerate(input_shape): | |||||
| if i[0] == 1: | |||||
| if input_dtype == mstype.uint8 and input_dtype == mstype.int8: | |||||
| mask_shape.append((input_shape[1] + 31) // 32) | |||||
| else: | |||||
| mask_shape.append((input_shape[1] + 15) // 16) | |||||
| else: | |||||
| mask_shape.append(i[1]) | |||||
| if input_dtype == mstype.uint8 and input_dtype == mstype.int8: | |||||
| mask_shape.append(4) | |||||
| else: | |||||
| mask_shape.append(2) | |||||
| output_shape = (input_x['shape'], mask_shape) | |||||
| validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name) | |||||
| mask_dtype = mstype.uint8 | |||||
| output_dtype = (input_dtype, mask_dtype) | |||||
| return {'shape': output_shape, | |||||
| 'dtype': output_dtype, | |||||
| 'value': None} | |||||
| class Elu(PrimitiveWithInfer): | class Elu(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. | Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. | ||||
| @@ -2580,3 +2637,51 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||||
| def infer_dtype(self, input_x): | def infer_dtype(self, input_x): | ||||
| validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) | validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) | ||||
| return input_x | return input_x | ||||
| class ConfusionMulGrad(PrimitiveWithInfer): | |||||
| """ | |||||
| `output0` is the result of which input0 dot multily input1. | |||||
| `output1` is the result of which input0 dot multily input1, then reducesum it. | |||||
| Args: | |||||
| axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. | |||||
| Default:(), reduce all dimensions. Only constant value is allowed. | |||||
| keep_dims (bool): | |||||
| - If true, keep these reduced dimensions and the length is 1. | |||||
| - If false, don't keep these dimensions. Default:False. | |||||
| Inputs: | |||||
| - **input_0** (Tensor) - The input Tensor. | |||||
| - **input_1** (Tensor) - The input Tensor. | |||||
| - **input_2** (Tensor) - The input Tensor. | |||||
| outputs: | |||||
| - **output_0** (Tensor) - The same shape with `input0`. | |||||
| - **output_1** (Tensor) | |||||
| - If axis is (), and keep_dims is false, the output is a 0-D array representing | |||||
| the sum of all elements in the input array. | |||||
| - If axis is int, set as 2, and keep_dims is false, | |||||
| the shape of output is :math:`(x_1,x_3,...,x_R)`. | |||||
| - If axis is tuple(int), set as (2,3), and keep_dims is false, | |||||
| the shape of output is :math:`(x_1,x_4,...x_R)`. | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, axis = (), keep_dims = False): | |||||
| self.init_prim_io_names(inputs = ["input0", "input1", "input2"], outputs = ["output0", "output1"]) | |||||
| self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name) | |||||
| self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name) | |||||
| def infer_shape(self, input0_shape, input1_shape, input2_shape): | |||||
| outshape0 = input0_shape | |||||
| outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name) | |||||
| return outshape0, outshape1 | |||||
| def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype): | |||||
| validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name) | |||||
| validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name) | |||||
| return input0_dtype, input1_dtype | |||||
| @@ -0,0 +1,53 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.ops.composite import GradOperation | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Grad(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(Grad, self).__init__() | |||||
| self.grad = GradOperation(name="get_all", get_all=True) | |||||
| self.network = network | |||||
| @ms_function | |||||
| def construct(self, input): | |||||
| return self.grad(self.network)(input) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.relu_v2 = P.ReLUV2() | |||||
| def construct(self, x): | |||||
| return self.relu_v2(x) | |||||
| def test_net(): | |||||
| x = Tensor(np.ones((2,3,3,4)).astype(np.float32)) | |||||
| relu_net = Net() | |||||
| relu_output = relu_net(x) | |||||
| net = Grad(Net()) | |||||
| output_grad = net(x) | |||||
| print(relu_output[0].asnumpy()) | |||||
| print(relu_output[1].asnumpy()) | |||||
| print(len(output_grad)) | |||||
| print(output_grad[0].asnumpy()) | |||||
| @@ -582,6 +582,10 @@ test_case_nn_ops = [ | |||||
| 'block': P.ReLU6(), | 'block': P.ReLU6(), | ||||
| 'desc_inputs': [[1, 3, 4, 4]], | 'desc_inputs': [[1, 3, 4, 4]], | ||||
| 'desc_bprop': [[1, 3, 4, 4]]}), | 'desc_bprop': [[1, 3, 4, 4]]}), | ||||
| ('ReLUV2', { | |||||
| 'block': P.ReLUV2(), | |||||
| 'desc_inputs': [[1, 3, 4, 4]], | |||||
| 'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}), | |||||
| ('ReLUGrad', { | ('ReLUGrad', { | ||||
| 'block': G.ReluGrad(), | 'block': G.ReluGrad(), | ||||
| 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], | 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], | ||||
| @@ -1134,6 +1138,21 @@ test_case_other_ops = [ | |||||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | ||||
| Tensor(np.array([1.2]).astype(np.float32))], | Tensor(np.array([1.2]).astype(np.float32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('ConfusionMulGrad_1', { | |||||
| 'block': P.ConfusionMulGrad(axis = [0], keep_dims = False), | |||||
| 'desc_inputs': [[3, 2], [3, 2], [3, 2]], | |||||
| 'desc_bprop': [[3, 2], [2]], | |||||
| 'skip': ['backward']}), | |||||
| ('ConfusionMulGrad_2', { | |||||
| 'block': P.ConfusionMulGrad(axis = [0], keep_dims = True), | |||||
| 'desc_inputs': [[3, 2], [3, 2], [3, 2]], | |||||
| 'desc_bprop': [[3, 2], [1, 2]], | |||||
| 'skip': ['backward']}), | |||||
| ('ConfusionMulGrad_3', { | |||||
| 'block': P.ConfusionMulGrad(axis = (), keep_dims = True), | |||||
| 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], | |||||
| 'desc_bprop': [[2, 3, 4], [1, 1, 1]], | |||||
| 'skip': ['backward']}), | |||||
| ('HistogramSummary', { | ('HistogramSummary', { | ||||
| 'block': HistogramSummaryNet(), | 'block': HistogramSummaryNet(), | ||||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | ||||