From: @tom__chen Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -29,25 +29,7 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class FusedBatchNormExGpuKernel : public GpuKernel { | class FusedBatchNormExGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| FusedBatchNormExGpuKernel() | |||||
| : input_x_size_(0), | |||||
| input_z_size_(0), | |||||
| para_size_(0), | |||||
| output_size_(0), | |||||
| workspace_size_(0), | |||||
| reserve_size_(0), | |||||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||||
| bn_ops_(CUDNN_BATCHNORM_OPS_BN), | |||||
| epsilon_(10e-5), | |||||
| exp_avg_factor_(0.1), | |||||
| is_null_input_(false), | |||||
| x_desc_(nullptr), | |||||
| y_desc_(nullptr), | |||||
| z_desc_(nullptr), | |||||
| scale_bias_mean_var_desc_(nullptr), | |||||
| activation_desc_(nullptr), | |||||
| handle_(nullptr), | |||||
| cudnn_data_type_(CUDNN_DATA_FLOAT) {} | |||||
| FusedBatchNormExGpuKernel() { ResetResource(); } | |||||
| ~FusedBatchNormExGpuKernel() override { DestroyResource(); } | ~FusedBatchNormExGpuKernel() override { DestroyResource(); } | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -142,6 +124,30 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void ResetResource() noexcept override { | |||||
| input_x_size_ = 0; | |||||
| input_z_size_ = 0; | |||||
| para_size_ = 0; | |||||
| output_size_ = 0; | |||||
| workspace_size_ = 0; | |||||
| reserve_size_ = 0; | |||||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN; | |||||
| epsilon_ = 10e-5; | |||||
| exp_avg_factor_ = 0.1; | |||||
| is_null_input_ = false; | |||||
| x_desc_ = nullptr; | |||||
| y_desc_ = nullptr; | |||||
| z_desc_ = nullptr; | |||||
| scale_bias_mean_var_desc_ = nullptr; | |||||
| activation_desc_ = nullptr; | |||||
| handle_ = nullptr; | |||||
| cudnn_data_type_ = CUDNN_DATA_FLOAT; | |||||
| input_size_list_.clear(); | |||||
| output_size_list_.clear(); | |||||
| workspace_size_list_.clear(); | |||||
| } | |||||
| void DestroyResource() noexcept override { | void DestroyResource() noexcept override { | ||||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | ||||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); | CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); | ||||
| @@ -51,6 +51,8 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -23,6 +23,15 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| int64_t GetAndCheckFormat(const ValuePtr &value) { | |||||
| int64_t data_format; | |||||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | |||||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) { | |||||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC"; | |||||
| } | |||||
| return data_format; | |||||
| } | |||||
| AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: a tensor. | // Inputs: a tensor. | ||||
| @@ -235,6 +244,54 @@ AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const Pri | |||||
| return std::make_shared<AbstractTuple>(rets); | return std::make_shared<AbstractTuple>(rets); | ||||
| } | } | ||||
| AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: five tensors(x, gamma, beta, mean, variance). | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 5); | |||||
| AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(input_x); | |||||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||||
| ShapeVector x_shape = input_x->shape()->shape(); | |||||
| ShapeVector x_min_shape = input_x->shape()->min_shape(); | |||||
| ShapeVector x_max_shape = input_x->shape()->max_shape(); | |||||
| CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); | |||||
| if (x_shape.size() != 4) { | |||||
| MS_LOG(EXCEPTION) << "Input rank should 4."; | |||||
| } | |||||
| auto data_format_ptr = primitive->GetAttr("format"); | |||||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||||
| int64_t data_format = GetAndCheckFormat(data_format_ptr); | |||||
| int64_t c_axis = 1; | |||||
| if (data_format == Format::NHWC) { | |||||
| c_axis = 3; | |||||
| } | |||||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||||
| AbstractTensorPtr arg_spec = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||||
| MS_EXCEPTION_IF_NULL(arg_spec); | |||||
| MS_EXCEPTION_IF_NULL(arg_spec->shape()); | |||||
| ShapeVector arg_shape = arg_spec->shape()->shape(); | |||||
| if (arg_shape.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size(); | |||||
| } | |||||
| if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) { | |||||
| MS_LOG(EXCEPTION) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis << "]=" << x_shape[c_axis] | |||||
| << ", but got " << arg_shape[0]; | |||||
| } | |||||
| } | |||||
| AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| ShapeVector gamma_shape = input_gamma->shape()->shape(); | |||||
| ShapeVector gamma_min_shape = input_gamma->shape()->min_shape(); | |||||
| ShapeVector gamma_max_shape = input_gamma->shape()->max_shape(); | |||||
| CheckMinMaxShape(gamma_shape, &gamma_min_shape, &gamma_max_shape); | |||||
| ShapePtr output_shape_ptr = std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape); | |||||
| AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr); | |||||
| ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape); | |||||
| AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr); | |||||
| AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma, output_gamma}; | |||||
| return std::make_shared<AbstractTuple>(rets); | |||||
| } | |||||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | ||||
| @@ -311,15 +368,6 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa | |||||
| } | } | ||||
| } | } | ||||
| int64_t GetAndCheckFormat(const ValuePtr &value) { | |||||
| int64_t data_format; | |||||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | |||||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) { | |||||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC"; | |||||
| } | |||||
| return data_format; | |||||
| } | |||||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| @@ -141,6 +141,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | ||||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | ||||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | ||||
| {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}}, | |||||
| {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | ||||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | ||||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | {prim::kPrimConv2D, {InferImplConv2D, true}}, | ||||
| @@ -871,11 +871,12 @@ class FusedBatchNorm(Primitive): | |||||
| self.target = context.get_context("device_target") | self.target = context.get_context("device_target") | ||||
| class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| class FusedBatchNormEx(PrimitiveWithCheck): | |||||
| r""" | r""" | ||||
| FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve) | FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve) | ||||
| than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that | than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that | ||||
| moving mean and moving variance will be computed instead of being loaded. | |||||
| moving mean and moving variance will be computed instead of being loaded. FusedBatchNormEx currently only | |||||
| supports 4D inputs. | |||||
| Batch Normalization is widely used in convolutional networks. This operation applies | Batch Normalization is widely used in convolutional networks. This operation applies | ||||
| Batch Normalization over input to avoid internal covariate shift as described in the | Batch Normalization over input to avoid internal covariate shift as described in the | ||||
| @@ -899,7 +900,7 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| Default: "NCHW". | Default: "NCHW". | ||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`, | |||||
| - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`, | |||||
| data type: float16 or float32. | data type: float16 or float32. | ||||
| - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, | - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, | ||||
| data type: float32. | data type: float32. | ||||
| @@ -970,25 +971,22 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||
| self.add_prim_attr('data_format', self.format) | self.add_prim_attr('data_format', self.format) | ||||
| def infer_shape(self, input_x, scale, bias, mean, variance): | |||||
| def check_shape(self, input_x, scale, bias, mean, variance): | |||||
| input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) | input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) | ||||
| validator.check_equal_int(len(input_shape_norm), 4, "x rank", self.name) | |||||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | validator.check_equal_int(len(scale), 1, "scale rank", self.name) | ||||
| validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | ||||
| validator.check("scale shape[0]", scale[0], "input channel", input_shape_norm[1], Rel.EQ, self.name) | |||||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | validator.check_equal_int(len(mean), 1, "mean rank", self.name) | ||||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | ||||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | ||||
| return (input_x, scale, scale, scale, scale, scale) | |||||
| def infer_dtype(self, input_x, scale, bias, mean, variance): | |||||
| def check_dtype(self, input_x, scale, bias, mean, variance): | |||||
| validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) | validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) | ||||
| args = {"scale": scale, "bias": bias} | args = {"scale": scale, "bias": bias} | ||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) | validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) | ||||
| args_moving = {"mean": mean, "variance": variance} | args_moving = {"mean": mean, "variance": variance} | ||||
| valid_dtypes = [mstype.tensor_type(mstype.float32)] | valid_dtypes = [mstype.tensor_type(mstype.float32)] | ||||
| validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name) | validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name) | ||||
| return (input_x, scale, scale, scale, scale, scale) | |||||
| class InstanceNorm(PrimitiveWithInfer): | class InstanceNorm(PrimitiveWithInfer): | ||||
| @@ -0,0 +1,128 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.nn import Cell | |||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from mindspore.ops import operations as P | |||||
| class NetFusedBatchNormEx(Cell): | |||||
| def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None): | |||||
| super(NetFusedBatchNormEx, self).__init__() | |||||
| self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1) | |||||
| self.moving_mean = Parameter(initializer( | |||||
| mean_init, num_features), name="mean", requires_grad=False) | |||||
| self.moving_variance = Parameter(initializer( | |||||
| var_init, num_features), name="variance", requires_grad=False) | |||||
| self.gamma = Parameter(initializer( | |||||
| gamma_init, num_features), name="gamma", requires_grad=True) | |||||
| self.beta = Parameter(initializer( | |||||
| beta_init, num_features), name="beta", requires_grad=True) | |||||
| self.dynshape = inner.GpuConvertToDynamicShape() | |||||
| def construct(self, x): | |||||
| x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance) | |||||
| return x | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_fused_bn_ex(): | |||||
| x = np.array([[ | |||||
| [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], | |||||
| [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) | |||||
| expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], | |||||
| [-0.1471, 0.7706, 1.6882, 2.6059], | |||||
| [0.3118, 1.6882, 2.1471, 2.1471], | |||||
| [0.7706, 0.3118, 2.6059, -0.1471]], | |||||
| [[0.9119, 1.8518, 1.3819, -0.0281], | |||||
| [-0.0281, 0.9119, 1.3819, 1.8518], | |||||
| [2.7918, 0.4419, -0.4981, 0.9119], | |||||
| [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) | |||||
| weight = np.ones(2).astype(np.float32) | |||||
| bias = np.ones(2).astype(np.float32) | |||||
| moving_mean = np.ones(2).astype(np.float32) | |||||
| moving_var = np.ones(2).astype(np.float32) | |||||
| error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| bn_net = NetFusedBatchNormEx(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var)) | |||||
| output_list = bn_net(Tensor(x)) | |||||
| output = output_list[0] | |||||
| diff = output.asnumpy() - expect_output | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| class NetFusedBatchNormExDynamic(Cell): | |||||
| def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None): | |||||
| super(NetFusedBatchNormExDynamic, self).__init__() | |||||
| self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1) | |||||
| self.moving_mean = Parameter(initializer( | |||||
| mean_init, num_features), name="mean", requires_grad=False) | |||||
| self.moving_variance = Parameter(initializer( | |||||
| var_init, num_features), name="variance", requires_grad=False) | |||||
| self.gamma = Parameter(initializer( | |||||
| gamma_init, num_features), name="gamma", requires_grad=True) | |||||
| self.beta = Parameter(initializer( | |||||
| beta_init, num_features), name="beta", requires_grad=True) | |||||
| self.dynshape = inner.GpuConvertToDynamicShape() | |||||
| def construct(self, x): | |||||
| x = self.dynshape(x) | |||||
| x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance) | |||||
| return x | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_fused_bn_ex_dynamic(): | |||||
| x = np.array([[ | |||||
| [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], | |||||
| [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) | |||||
| expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], | |||||
| [-0.1471, 0.7706, 1.6882, 2.6059], | |||||
| [0.3118, 1.6882, 2.1471, 2.1471], | |||||
| [0.7706, 0.3118, 2.6059, -0.1471]], | |||||
| [[0.9119, 1.8518, 1.3819, -0.0281], | |||||
| [-0.0281, 0.9119, 1.3819, 1.8518], | |||||
| [2.7918, 0.4419, -0.4981, 0.9119], | |||||
| [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) | |||||
| weight = np.ones(2).astype(np.float32) | |||||
| bias = np.ones(2).astype(np.float32) | |||||
| moving_mean = np.ones(2).astype(np.float32) | |||||
| moving_var = np.ones(2).astype(np.float32) | |||||
| error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| bn_net = NetFusedBatchNormExDynamic(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var)) | |||||
| output_list = bn_net(Tensor(x)) | |||||
| output = output_list[0] | |||||
| diff = output.asnumpy() - expect_output | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||