diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 6510ef79ea..0a6fb0b3f6 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -174,6 +174,8 @@ const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad")
 const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
 const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
 const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
+const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
+const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
 const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
 const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
 const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index b37d068d94..8c63660c3e 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -175,6 +175,8 @@ extern const PrimitivePtr kPrimTanhGrad;
 extern const PrimitivePtr kPrimPooling;
 extern const PrimitivePtr kPrimPoolingGrad;
 extern const PrimitivePtr kPrimFusedBatchNorm;
+extern const PrimitivePtr kPrimBatchNorm;
+extern const PrimitivePtr kPrimBatchNormGrad;
 extern const PrimitivePtr kPrimConv2D;
 extern const PrimitivePtr kPrimMaxPool;
 extern const PrimitivePtr kPrimMaxPoolGrad;
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index e7ea44b555..fb98d16c26 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -221,7 +221,6 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {prim::kPrimAssign->name(), ADPT_DESC(Assign)},
     {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
     {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
-    {prim::kPrimFusedBatchNormGrad->name(), ADPT_DESC(FusedBatchNormGrad)},
     {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)},
     {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)},
     {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)},
@@ -229,7 +228,6 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)},
     {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)},
     {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)},
-    {prim::kPrimFusedBatchNorm->name(), ADPT_DESC(FusedBatchNorm, BatchNorm)},
     {string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
     {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
     {string(kNameReshape), ADPT_DESC(Reshape)},
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index b1195cfb1c..8159204155 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -703,28 +703,6 @@ INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
 ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}};
 
-// FusedBatchNorm
-INPUT_MAP(FusedBatchNorm) = {
-  {1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(variance)}};
-ATTR_MAP(FusedBatchNorm) = {{"mode", ATTR_DESC(mode, AnyTraits())},
-                            {"momentum", ATTR_DESC(moving_average_fraction, AnyTraits())},
-                            {"epsilon", ATTR_DESC(epsilon, AnyTraits())}};
-OUTPUT_MAP(FusedBatchNorm) = {{0, OUTPUT_DESC(y)},
-                              {1, OUTPUT_DESC(running_mean)},
-                              {2, OUTPUT_DESC(running_variance)},
-                              {3, OUTPUT_DESC(save_mean)},
-                              {4, OUTPUT_DESC(save_inv_variance)}};
-
-// FusedBatchNromGrad
-INPUT_MAP(FusedBatchNormGrad) = {{1, INPUT_DESC(dy)},
-                                 {2, INPUT_DESC(x)},
-                                 {3, INPUT_DESC(scale)},
-                                 {4, INPUT_DESC(save_mean)},
-                                 {5, INPUT_DESC(save_inv_variance)}};
-ATTR_MAP(FusedBatchNormGrad) = {{"momentum", ATTR_DESC(momentum, AnyTraits())},
-                                {"epsilon", ATTR_DESC(epsilon, AnyTraits())}};
-OUTPUT_MAP(FusedBatchNormGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(bn_scale)}, {2, OUTPUT_DESC(bn_bias)}};
-
 // BiasAddGrad
 INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}};
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index a2dc16c285..21cac35121 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -82,10 +82,6 @@ DECLARE_OP_USE_OUTPUT(HcomAllGather)
 DECLARE_OP_ADAPTER(Variable)
 DECLARE_OP_ADAPTER(ReluGrad)
 DECLARE_OP_USE_OUTPUT(ReluGrad)
-DECLARE_OP_ADAPTER(FusedBatchNorm)
-DECLARE_OP_USE_OUTPUT(FusedBatchNorm)
-DECLARE_OP_ADAPTER(FusedBatchNormGrad)
-DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad)
 DECLARE_OP_ADAPTER(BiasAddGrad)
 DECLARE_OP_USE_OUTPUT(BiasAddGrad)
 DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py
index 3e139a2db5..f06c5fd30a 100644
--- a/mindspore/nn/layer/image.py
+++ b/mindspore/nn/layer/image.py
@@ -58,6 +58,7 @@ class ImageGradients(Cell):
         super(ImageGradients, self).__init__()
 
     def construct(self, images):
+        _check_input_4d(F.shape(images), "images", self.cls_name)
         batch_size, depth, height, width = P.Shape()(images)
         dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
         dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
@@ -151,8 +152,8 @@
         self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
 
     def construct(self, img1, img2):
-        _check_input_4d(F.shape(img1), "img1", "SSIM")
-        _check_input_4d(F.shape(img2), "img2", "SSIM")
+        _check_input_4d(F.shape(img1), "img1", self.cls_name)
+        _check_input_4d(F.shape(img2), "img2", self.cls_name)
         P.SameTypeShape()(img1, img2)
         max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
         img1 = _convert_img_dtype_to_float32(img1, self.max_val)
@@ -244,8 +245,8 @@
         self.max_val = max_val
 
     def construct(self, img1, img2):
-        _check_input_4d(F.shape(img1), "img1", "PSNR")
-        _check_input_4d(F.shape(img2), "img2", "PSNR")
+        _check_input_4d(F.shape(img1), "img1", self.cls_name)
+        _check_input_4d(F.shape(img2), "img2", self.cls_name)
         P.SameTypeShape()(img1, img2)
         max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
         img1 = _convert_img_dtype_to_float32(img1, self.max_val)
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index abffde1865..b69b083e03 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1016,6 +1016,7 @@ class Argmin(PrimitiveWithInfer):
         """init Argmin"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
         validator.check_value_type("axis", axis, [int], self.name)
+        validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
         self.axis = axis
         self.add_prim_attr('output_type', output_type)
 
@@ -1726,7 +1727,9 @@ class Diag(PrimitiveWithInfer):
     def infer_value(self, x):
         if x is None:
             return None
-        validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
+        # do constant-folding only when x rank is 1
+        if len(x.shape()) != 1:
+            return None
         ret = np.diag(x.asnumpy())
         return Tensor(ret)
 
@@ -1752,7 +1755,7 @@ class DiagPart(PrimitiveWithInfer):
         >>>                   [0, 0, 3, 0],
         >>>                   [0, 0, 0, 4]])
         >>> diag_part = P.DiagPart()
-        >>> diag_part(x)
+        >>> diag_part(input_x)
         [1, 2, 3, 4]
     """
 
@@ -1776,7 +1779,9 @@
     def infer_value(self, x):
         if x is None:
             return None
-        validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
+        # do constant-folding only when x rank is 2
+        if len(x.shape()) != 2:
+            return None
         ret = np.diag(x.asnumpy())
         return Tensor(ret)
 
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 1dfe93136b..a634ebbb71 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2037,7 +2037,7 @@ class Atan2(_MathBinaryOp):
     r"""
     Returns arctangent of input_x/input_y element-wise.
 
-    It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})`
+    It returns :math:`\theta\ \in\ [-\pi, \pi]`
     such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
 
     Inputs:
diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc
index 4388312592..0f47499665 100644
--- a/tests/ut/cpp/transform/convert_test.cc
+++ b/tests/ut/cpp/transform/convert_test.cc
@@ -147,13 +147,13 @@ TEST_F(TestConvert, TestReluOps) {
 }
 
 TEST_F(TestConvert, TestConvertBatchNorm) {
-  PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
-  fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
-  fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
+  PrimitivePtr batch_norm = prim::kPrimBatchNorm;
+  batch_norm->AddAttr("epsilon", MakeValue(0.001f));
+  batch_norm->AddAttr("momentum", MakeValue(0.1f));
 
   FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
   std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(fused_batch_norm));
+  inputs.push_back(NewValueNode(batch_norm));
   for (unsigned int i = 0; i < 5; i++) {
     inputs.push_back(anf_graph->add_parameter());
   }
diff --git a/tests/ut/python/nn/test_image_gradients.py b/tests/ut/python/nn/test_image_gradients.py
index a2b9495443..e268ceb9d9 100644
--- a/tests/ut/python/nn/test_image_gradients.py
+++ b/tests/ut/python/nn/test_image_gradients.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test image gradients """
 import numpy as np
+import pytest
 import mindspore.nn as nn
 import mindspore.context as context
 import mindspore.common.dtype as mstype
@@ -47,3 +48,10 @@ def test_compile_multi_channel():
                              [[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype)
     net = Net()
     _executor.compile(net, image)
+
+def test_invalid_5d_input():
+    dtype = mstype.float32
+    image = Tensor(np.random.random([4, 1, 16, 16, 1]), dtype=dtype)
+    net = Net()
+    with pytest.raises(ValueError):
+        _executor.compile(net, image)
\ No newline at end of file
diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py
index 01e7e32d50..61b8d48fea 100644
--- a/tests/ut/python/ops/test_array_ops.py
+++ b/tests/ut/python/ops/test_array_ops.py
@@ -14,16 +14,15 @@
 # ============================================================================
 """ test array ops """
 import functools
+import pytest
 import numpy as np
 import mindspore as ms
 from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.ops import composite as C
 from mindspore.ops import prim_attr_register
+from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import Primitive, PrimitiveWithInfer
-from mindspore.common.dtype import get_py_obj_dtype
 from mindspore._c_expression import signature_dtype as sig_dtype
 from mindspore._c_expression import signature_rw as sig_rw
 from mindspore._c_expression import signature_kind as sig_kind
@@ -96,6 +95,17 @@ def test_select():
     expect = np.array([[1, 8, 9], [10, 5, 6]])
     assert np.all(output.asnumpy() == expect)
 
+def test_argmin_invalid_output_type():
+    P.Argmin(-1, mstype.int64)
+    P.Argmin(-1, mstype.int32)
+    with pytest.raises(TypeError):
+        P.Argmin(-1, mstype.float32)
+    with pytest.raises(TypeError):
+        P.Argmin(-1, mstype.float64)
+    with pytest.raises(TypeError):
+        P.Argmin(-1, mstype.uint8)
+    with pytest.raises(TypeError):
+        P.Argmin(-1, mstype.bool_)
 
 class CustomOP(PrimitiveWithInfer):
     __mindspore_signature__ = (sig_dtype.T, sig_dtype.T, sig_dtype.T1,
diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py
index b866c7c556..a4a645a7ef 100755
--- a/tests/ut/python/ops/test_math_ops.py
+++ b/tests/ut/python/ops/test_math_ops.py
@@ -17,6 +17,7 @@ import functools
 import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
+from mindspore.common.api import _executor
 from mindspore.common import dtype as mstype
 from mindspore.ops import prim_attr_register, PrimitiveWithInfer
 from mindspore import Tensor