Merge pull request !794 from zhaozhenlong/fix-issues-image-gradients-complement-checktags/v0.3.0-alpha
| @@ -174,6 +174,8 @@ const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad") | |||||
| const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | ||||
| const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | ||||
| const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | ||||
| const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||||
| const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||||
| const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | ||||
| const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | ||||
| const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | ||||
| @@ -175,6 +175,8 @@ extern const PrimitivePtr kPrimTanhGrad; | |||||
| extern const PrimitivePtr kPrimPooling; | extern const PrimitivePtr kPrimPooling; | ||||
| extern const PrimitivePtr kPrimPoolingGrad; | extern const PrimitivePtr kPrimPoolingGrad; | ||||
| extern const PrimitivePtr kPrimFusedBatchNorm; | extern const PrimitivePtr kPrimFusedBatchNorm; | ||||
| extern const PrimitivePtr kPrimBatchNorm; | |||||
| extern const PrimitivePtr kPrimBatchNormGrad; | |||||
| extern const PrimitivePtr kPrimConv2D; | extern const PrimitivePtr kPrimConv2D; | ||||
| extern const PrimitivePtr kPrimMaxPool; | extern const PrimitivePtr kPrimMaxPool; | ||||
| extern const PrimitivePtr kPrimMaxPoolGrad; | extern const PrimitivePtr kPrimMaxPoolGrad; | ||||
| @@ -221,7 +221,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, | {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, | ||||
| {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, | {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, | ||||
| {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, | {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, | ||||
| {prim::kPrimFusedBatchNormGrad->name(), ADPT_DESC(FusedBatchNormGrad)}, | |||||
| {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, | {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, | ||||
| {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, | {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, | ||||
| {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, | {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, | ||||
| @@ -229,7 +228,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, | {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, | ||||
| {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, | {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, | ||||
| {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, | {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, | ||||
| {prim::kPrimFusedBatchNorm->name(), ADPT_DESC(FusedBatchNorm, BatchNorm)}, | |||||
| {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, | {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, | ||||
| {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, | {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, | ||||
| {string(kNameReshape), ADPT_DESC(Reshape)}, | {string(kNameReshape), ADPT_DESC(Reshape)}, | ||||
| @@ -703,28 +703,6 @@ INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; | |||||
| ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP; | ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}}; | OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}}; | ||||
| // FusedBatchNorm | |||||
| INPUT_MAP(FusedBatchNorm) = { | |||||
| {1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(variance)}}; | |||||
| ATTR_MAP(FusedBatchNorm) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}, | |||||
| {"momentum", ATTR_DESC(moving_average_fraction, AnyTraits<float>())}, | |||||
| {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}}; | |||||
| OUTPUT_MAP(FusedBatchNorm) = {{0, OUTPUT_DESC(y)}, | |||||
| {1, OUTPUT_DESC(running_mean)}, | |||||
| {2, OUTPUT_DESC(running_variance)}, | |||||
| {3, OUTPUT_DESC(save_mean)}, | |||||
| {4, OUTPUT_DESC(save_inv_variance)}}; | |||||
| // FusedBatchNromGrad | |||||
| INPUT_MAP(FusedBatchNormGrad) = {{1, INPUT_DESC(dy)}, | |||||
| {2, INPUT_DESC(x)}, | |||||
| {3, INPUT_DESC(scale)}, | |||||
| {4, INPUT_DESC(save_mean)}, | |||||
| {5, INPUT_DESC(save_inv_variance)}}; | |||||
| ATTR_MAP(FusedBatchNormGrad) = {{"momentum", ATTR_DESC(momentum, AnyTraits<float>())}, | |||||
| {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}}; | |||||
| OUTPUT_MAP(FusedBatchNormGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(bn_scale)}, {2, OUTPUT_DESC(bn_bias)}}; | |||||
| // BiasAddGrad | // BiasAddGrad | ||||
| INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | ||||
| @@ -82,10 +82,6 @@ DECLARE_OP_USE_OUTPUT(HcomAllGather) | |||||
| DECLARE_OP_ADAPTER(Variable) | DECLARE_OP_ADAPTER(Variable) | ||||
| DECLARE_OP_ADAPTER(ReluGrad) | DECLARE_OP_ADAPTER(ReluGrad) | ||||
| DECLARE_OP_USE_OUTPUT(ReluGrad) | DECLARE_OP_USE_OUTPUT(ReluGrad) | ||||
| DECLARE_OP_ADAPTER(FusedBatchNorm) | |||||
| DECLARE_OP_USE_OUTPUT(FusedBatchNorm) | |||||
| DECLARE_OP_ADAPTER(FusedBatchNormGrad) | |||||
| DECLARE_OP_USE_OUTPUT(FusedBatchNormGrad) | |||||
| DECLARE_OP_ADAPTER(BiasAddGrad) | DECLARE_OP_ADAPTER(BiasAddGrad) | ||||
| DECLARE_OP_USE_OUTPUT(BiasAddGrad) | DECLARE_OP_USE_OUTPUT(BiasAddGrad) | ||||
| DECLARE_OP_ADAPTER(MaxPoolWithArgmax) | DECLARE_OP_ADAPTER(MaxPoolWithArgmax) | ||||
| @@ -58,6 +58,7 @@ class ImageGradients(Cell): | |||||
| super(ImageGradients, self).__init__() | super(ImageGradients, self).__init__() | ||||
| def construct(self, images): | def construct(self, images): | ||||
| _check_input_4d(F.shape(images), "images", self.cls_name) | |||||
| batch_size, depth, height, width = P.Shape()(images) | batch_size, depth, height, width = P.Shape()(images) | ||||
| dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] | dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] | ||||
| dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) | dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) | ||||
| @@ -151,8 +152,8 @@ class SSIM(Cell): | |||||
| self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) | self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) | ||||
| def construct(self, img1, img2): | def construct(self, img1, img2): | ||||
| _check_input_4d(F.shape(img1), "img1", "SSIM") | |||||
| _check_input_4d(F.shape(img2), "img2", "SSIM") | |||||
| _check_input_4d(F.shape(img1), "img1", self.cls_name) | |||||
| _check_input_4d(F.shape(img2), "img2", self.cls_name) | |||||
| P.SameTypeShape()(img1, img2) | P.SameTypeShape()(img1, img2) | ||||
| max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | ||||
| img1 = _convert_img_dtype_to_float32(img1, self.max_val) | img1 = _convert_img_dtype_to_float32(img1, self.max_val) | ||||
| @@ -244,8 +245,8 @@ class PSNR(Cell): | |||||
| self.max_val = max_val | self.max_val = max_val | ||||
| def construct(self, img1, img2): | def construct(self, img1, img2): | ||||
| _check_input_4d(F.shape(img1), "img1", "PSNR") | |||||
| _check_input_4d(F.shape(img2), "img2", "PSNR") | |||||
| _check_input_4d(F.shape(img1), "img1", self.cls_name) | |||||
| _check_input_4d(F.shape(img2), "img2", self.cls_name) | |||||
| P.SameTypeShape()(img1, img2) | P.SameTypeShape()(img1, img2) | ||||
| max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | ||||
| img1 = _convert_img_dtype_to_float32(img1, self.max_val) | img1 = _convert_img_dtype_to_float32(img1, self.max_val) | ||||
| @@ -1016,6 +1016,7 @@ class Argmin(PrimitiveWithInfer): | |||||
| """init Argmin""" | """init Argmin""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name) | |||||
| self.axis = axis | self.axis = axis | ||||
| self.add_prim_attr('output_type', output_type) | self.add_prim_attr('output_type', output_type) | ||||
| @@ -1726,7 +1727,9 @@ class Diag(PrimitiveWithInfer): | |||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| if x is None: | if x is None: | ||||
| return None | return None | ||||
| validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name) | |||||
| # do constant-folding only when x rank is 1 | |||||
| if len(x.shape()) != 1: | |||||
| return None | |||||
| ret = np.diag(x.asnumpy()) | ret = np.diag(x.asnumpy()) | ||||
| return Tensor(ret) | return Tensor(ret) | ||||
| @@ -1752,7 +1755,7 @@ class DiagPart(PrimitiveWithInfer): | |||||
| >>> [0, 0, 3, 0], | >>> [0, 0, 3, 0], | ||||
| >>> [0, 0, 0, 4]]) | >>> [0, 0, 0, 4]]) | ||||
| >>> diag_part = P.DiagPart() | >>> diag_part = P.DiagPart() | ||||
| >>> diag_part(x) | |||||
| >>> diag_part(input_x) | |||||
| [1, 2, 3, 4] | [1, 2, 3, 4] | ||||
| """ | """ | ||||
| @@ -1776,7 +1779,9 @@ class DiagPart(PrimitiveWithInfer): | |||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| if x is None: | if x is None: | ||||
| return None | return None | ||||
| validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name) | |||||
| # do constant-folding only when x rank is 2 | |||||
| if len(x.shape()) != 2: | |||||
| return None | |||||
| ret = np.diag(x.asnumpy()) | ret = np.diag(x.asnumpy()) | ||||
| return Tensor(ret) | return Tensor(ret) | ||||
| @@ -2037,7 +2037,7 @@ class Atan2(_MathBinaryOp): | |||||
| r""" | r""" | ||||
| Returns arctangent of input_x/input_y element-wise. | Returns arctangent of input_x/input_y element-wise. | ||||
| It returns :math:`\theta\ \in\ (-\frac{\pi}{2}, \frac{\pi}{2})` | |||||
| It returns :math:`\theta\ \in\ [-\pi, \pi]` | |||||
| such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`. | such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`. | ||||
| Inputs: | Inputs: | ||||
| @@ -147,13 +147,13 @@ TEST_F(TestConvert, TestReluOps) { | |||||
| } | } | ||||
| TEST_F(TestConvert, TestConvertBatchNorm) { | TEST_F(TestConvert, TestConvertBatchNorm) { | ||||
| PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm; | |||||
| fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f)); | |||||
| fused_batch_norm->AddAttr("momentum", MakeValue(0.1f)); | |||||
| PrimitivePtr batch_norm = prim::kPrimBatchNorm; | |||||
| batch_norm->AddAttr("epsilon", MakeValue(0.001f)); | |||||
| batch_norm->AddAttr("momentum", MakeValue(0.1f)); | |||||
| FuncGraphPtr anf_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr anf_graph = std::make_shared<FuncGraph>(); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(fused_batch_norm)); | |||||
| inputs.push_back(NewValueNode(batch_norm)); | |||||
| for (unsigned int i = 0; i < 5; i++) { | for (unsigned int i = 0; i < 5; i++) { | ||||
| inputs.push_back(anf_graph->add_parameter()); | inputs.push_back(anf_graph->add_parameter()); | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test image gradients """ | """ test image gradients """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| @@ -47,3 +48,10 @@ def test_compile_multi_channel(): | |||||
| [[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype) | [[[10,20],[30,40]], [[50,60],[70,80]]]]), dtype=dtype) | ||||
| net = Net() | net = Net() | ||||
| _executor.compile(net, image) | _executor.compile(net, image) | ||||
| def test_invalid_5d_input(): | |||||
| dtype = mstype.float32 | |||||
| image = Tensor(np.random.random([4, 1, 16, 16, 1]), dtype=dtype) | |||||
| net = Net() | |||||
| with pytest.raises(ValueError): | |||||
| _executor.compile(net, image) | |||||
| @@ -14,16 +14,15 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test array ops """ | """ test array ops """ | ||||
| import functools | import functools | ||||
| import pytest | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import prim_attr_register | from mindspore.ops import prim_attr_register | ||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops.primitive import Primitive, PrimitiveWithInfer | from mindspore.ops.primitive import Primitive, PrimitiveWithInfer | ||||
| from mindspore.common.dtype import get_py_obj_dtype | |||||
| from mindspore._c_expression import signature_dtype as sig_dtype | from mindspore._c_expression import signature_dtype as sig_dtype | ||||
| from mindspore._c_expression import signature_rw as sig_rw | from mindspore._c_expression import signature_rw as sig_rw | ||||
| from mindspore._c_expression import signature_kind as sig_kind | from mindspore._c_expression import signature_kind as sig_kind | ||||
| @@ -96,6 +95,17 @@ def test_select(): | |||||
| expect = np.array([[1, 8, 9], [10, 5, 6]]) | expect = np.array([[1, 8, 9], [10, 5, 6]]) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| def test_argmin_invalid_output_type(): | |||||
| P.Argmin(-1, mstype.int64) | |||||
| P.Argmin(-1, mstype.int32) | |||||
| with pytest.raises(TypeError): | |||||
| P.Argmin(-1, mstype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| P.Argmin(-1, mstype.float64) | |||||
| with pytest.raises(TypeError): | |||||
| P.Argmin(-1, mstype.uint8) | |||||
| with pytest.raises(TypeError): | |||||
| P.Argmin(-1, mstype.bool_) | |||||
| class CustomOP(PrimitiveWithInfer): | class CustomOP(PrimitiveWithInfer): | ||||
| __mindspore_signature__ = (sig_dtype.T, sig_dtype.T, sig_dtype.T1, | __mindspore_signature__ = (sig_dtype.T, sig_dtype.T, sig_dtype.T1, | ||||
| @@ -17,6 +17,7 @@ import functools | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.api import _executor | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | from mindspore.ops import prim_attr_register, PrimitiveWithInfer | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||