diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
index f6c3205df6..52815ab10d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
@@ -97,11 +97,6 @@ class BatchNormGpuKernel : public GpuKernel {
 
     InitResource();
     is_train_ = GetAttr<bool>(kernel_node, "is_training");
-    if (is_train_) {
-      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-    } else {
-      mode_ = CUDNN_BATCHNORM_SPATIAL;
-    }
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
     exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
 
@@ -118,8 +113,8 @@ class BatchNormGpuKernel : public GpuKernel {
     }
 
     auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    if (shape.size() != 4) {
-      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4";
+    if (shape.size() != 4 && shape.size() != 2) {
+      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 2D or 4D";
     }
     is_null_input_ = CHECK_NULL_INPUT(shape);
     if (is_null_input_) {
@@ -127,6 +122,15 @@ class BatchNormGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+
+    if (shape.size() == 2) {
+      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+    } else if (is_train_) {
+      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+    } else {
+      mode_ = CUDNN_BATCHNORM_SPATIAL;
+    }
+
     auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto format_attr = GetAttr<std::string>(kernel_node, "format");
     if (format_attr == kOpFormat_NHWC) {
@@ -242,7 +246,13 @@ class BatchNormGpuKernel : public GpuKernel {
   void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
     cudnnTensorFormat_t cudnn_format;
     int batch, channel, height, width;
-    if (format == kOpFormat_NHWC) {
+    if (shape.size() == 2) {
+      batch = SizeToInt(shape[0]);
+      channel = SizeToInt(shape[1]);
+      height = 1;
+      width = 1;
+      cudnn_format = CUDNN_TENSOR_NCHW;
+    } else if (format == kOpFormat_NHWC) {
       batch = SizeToInt(shape[0]);
       height = SizeToInt(shape[1]);
       width = SizeToInt(shape[2]);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
index 8c0a4ca7d4..e64856311a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
@@ -124,7 +124,6 @@ class BatchNormGradGpuKernel : public GpuKernel {
     }
     InitResource();
-    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
 
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
 
@@ -140,8 +139,8 @@ class BatchNormGradGpuKernel : public GpuKernel {
     }
 
     auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    if (shape.size() != 4) {
-      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
+    if (shape.size() != 4 && shape.size() != 2) {
+      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 2D or 4D";
     }
     is_null_input_ = CHECK_NULL_INPUT(shape);
     if (is_null_input_) {
@@ -149,6 +148,12 @@ class BatchNormGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    if (shape.size() == 2) {
+      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+    } else {
+      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+    }
+
     std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto format_attr = GetAttr<std::string>(kernel_node, "format");
     if (format_attr == kOpFormat_NHWC) {
@@ -234,7 +239,13 @@ class BatchNormGradGpuKernel : public GpuKernel {
  private:
   void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
     cudnnTensorFormat_t cudnn_format;
-    if (format == kOpFormat_NHWC) {
+    if (shape.size() == 2) {
+      batch_ = SizeToInt(shape[0]);
+      channel_ = SizeToInt(shape[1]);
+      height_ = 1;
+      width_ = 1;
+      cudnn_format = CUDNN_TENSOR_NCHW;
+    } else if (format == kOpFormat_NHWC) {
       batch_ = SizeToInt(shape[0]);
       height_ = SizeToInt(shape[1]);
       width_ = SizeToInt(shape[2]);
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 93142d8f40..da65d2b058 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -288,7 +288,7 @@ class BatchNorm1d(_BatchNorm):
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Raises:
         TypeError: If `num_features` is not an int.
diff --git a/tests/st/ops/gpu/test_batchnorm_op.py b/tests/st/ops/gpu/test_batchnorm_op.py
index d8457c01f2..036927be5c 100644
--- a/tests/st/ops/gpu/test_batchnorm_op.py
+++ b/tests/st/ops/gpu/test_batchnorm_op.py
@@ -18,7 +18,8 @@
 import pytest
 
 import mindspore.context as context
 from mindspore.common.tensor import Tensor
-from mindspore.nn import BatchNorm2d
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn import BatchNorm2d, BatchNorm1d, SGD
 from mindspore.nn import Cell
 from mindspore.ops import composite as C
@@ -201,3 +202,139 @@ def test_infer_backward():
     ms_grad = Grad(ms_net)
     ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
     assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)
+
+
+class BatchNorm1d_Net(Cell):
+    def __init__(self, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros',
+                 moving_var_init='ones', use_batch_statistics=None):
+        super(BatchNorm1d_Net, self).__init__()
+        self.bn1 = BatchNorm1d(2, eps=0.00001, momentum=0.1, affine=affine, gamma_init=gamma_init, beta_init=beta_init,
+                               moving_mean_init=moving_mean_init, moving_var_init=moving_var_init,
+                               use_batch_statistics=use_batch_statistics)
+
+    def construct(self, x):
+        x = self.bn1(x)
+        return x
+
+class GradByListNet(Cell):
+    def __init__(self, network):
+        super(GradByListNet, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
+        self.network = network
+        self.params = ParameterTuple(network.trainable_params())
+
+    def construct(self, x, dy):
+        grad_op = self.grad(self.network, self.params)
+        output = grad_op(x, dy)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_1d_train():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    bn_net = BatchNorm1d_Net(use_batch_statistics=None)
+    grad_net = GradByListNet(bn_net)
+    optimizer = SGD(bn_net.trainable_params(), learning_rate=0.01, momentum=0.9)
+    bn_net.set_train(True)
+
+    x1 = np.array([[1.6243454, -0.6117564],
+                   [-0.5281718, -1.0729686],
+                   [0.86540765, -2.3015387],
+                   [1.7448118, -0.7612069],
+                   [0.3190391, -0.24937038]]).astype(np.float32)
+    dy1 = np.array([[1.4621079, -2.0601406],
+                    [-0.3224172, -0.38405436],
+                    [1.1337694, -1.0998913],
+                    [-0.1724282, -0.8778584],
+                    [0.04221375, 0.58281523]]).astype(np.float32)
+    x2 = np.array([[-0.19183555, -0.887629],
+                   [-0.7471583, 1.6924546],
+                   [0.05080776, -0.6369957],
+                   [0.19091548, 2.1002553],
+                   [0.12015896, 0.6172031]]).astype(np.float32)
+    dy2 = np.array([[0.30017033, -0.35224986],
+                    [-1.1425182, -0.34934273],
+                    [-0.20889424, 0.5866232],
+                    [0.8389834, 0.9311021],
+                    [0.2855873, 0.8851412]]).astype(np.float32)
+    x_train = [x1, x2]
+    dy_train = [dy1, dy2]
+
+    dx1 = np.array([[0.8120, -2.0371],
+                    [-0.2202, 0.5837],
+                    [0.8040, 0.1950],
+                    [-1.1823, -0.2786],
+                    [-0.2135, 1.5371]]).astype(np.float32)
+    gamma1 = np.array([0.9821, 0.9873]).astype(np.float32)
+    beta1 = np.array([-0.0214, 0.0384]).astype(np.float32)
+    mean1 = np.array([0.7246, -0.8994]).astype(np.float32)
+    variance1 = np.array([0.9036, 0.6559]).astype(np.float32)
+
+    dx2 = np.array([[1.1955, -0.4247],
+                    [-0.2425, -0.6789],
+                    [-1.4563, 0.3237],
+                    [0.8752, 0.3351],
+                    [-0.3719, 0.4448]]).astype(np.float32)
+    gamma2 = np.array([0.9370, 0.9687]).astype(np.float32)
+    beta2 = np.array([-0.0415, 0.0559]).astype(np.float32)
+    mean2 = np.array([-0.0314, 0.4294]).astype(np.float32)
+    variance2 = np.array([0.2213, 1.6822]).astype(np.float32)
+
+    exp_dx = [dx1, dx2]
+    exp_gamma = [gamma1, gamma2]
+    exp_beta = [beta1, beta2]
+    exp_mean = [mean1, mean2]
+    exp_variance = [variance1, variance2]
+
+    for data in zip(x_train, dy_train, exp_dx, exp_gamma, exp_beta, exp_mean, exp_variance):
+        output = grad_net(Tensor(data[0]), Tensor(data[1]))
+        assert np.allclose(output[0][0].asnumpy(), data[2], atol=1.0e-4)
+        optimizer(output[1])
+        assert np.allclose(bn_net.bn1.gamma.asnumpy(), data[3], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.beta.asnumpy(), data[4], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.moving_mean.asnumpy(), data[5], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.moving_variance.asnumpy(), data[6], atol=1.0e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_1d_eval():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gamma_init = Tensor(np.array([0.93700373, 0.96870345]).astype(np.float32))
+    beta_init = Tensor(np.array([-0.04145495, 0.05593072]).astype(np.float32))
+    mean_init = Tensor(np.array([-0.03142229, 0.4294087]).astype(np.float32))
+    variance_init = Tensor(np.array([0.2212921, 1.6822311]).astype(np.float32))
+    bn_net = BatchNorm1d_Net(affine=False, gamma_init=gamma_init, beta_init=beta_init, moving_mean_init=mean_init,
+                             moving_var_init=variance_init, use_batch_statistics=None)
+    bn_net.set_train(False)
+
+    x1 = np.array([[-1.1006192, 1.1447237],
+                   [0.9015907, 0.50249434],
+                   [0.90085596, -0.68372786],
+                   [-0.12289023, -0.93576944],
+                   [-0.26788807, 0.53035545]]).astype(np.float32)
+    x2 = np.array([[-0.7543979, 1.2528682],
+                   [0.5129298, -0.29809284],
+                   [0.48851815, -0.07557172],
+                   [1.1316293, 1.5198169],
+                   [2.1855755, -1.3964963]]).astype(np.float32)
+    x_test = [x1, x2]
+
+    y1 = np.array([[-2.1711, 0.5902],
+                   [1.8169, 0.1105],
+                   [1.8155, -0.7754],
+                   [-0.2236, -0.9637],
+                   [-0.5125, 0.1313]]).astype(np.float32)
+    y2 = np.array([[-1.4815, 0.6710],
+                   [1.0428, -0.4874],
+                   [0.9942, -0.3212],
+                   [2.2751, 0.8703],
+                   [4.3744, -1.3078]]).astype(np.float32)
+    y_test = [y1, y2]
+
+    for x, y in zip(x_test, y_test):
+        y_pred = bn_net(Tensor(x))
+        assert np.allclose(y_pred.asnumpy(), y, atol=1.0e-4)
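
Usage note (not part of the patch): with 2D shapes now accepted by BatchNormGpuKernel, nn.BatchNorm1d runs on the GPU backend. A minimal sketch, assuming a MindSpore build with GPU support; the feature count and input values below are illustrative only:

import numpy as np
import mindspore.context as context
from mindspore import Tensor, nn

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# A (N, C) input now selects CUDNN_BATCHNORM_PER_ACTIVATION inside the
# kernel; before this patch a 2D shape was rejected with an exception.
bn = nn.BatchNorm1d(num_features=4)
bn.set_train(False)  # inference path uses the moving statistics

x = Tensor(np.random.randn(8, 4).astype(np.float32))
y = bn(x)  # shape (8, 4)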
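For reference, the arithmetic behind the inference path exercised by test_1d_eval: for a (N, C) input, the kernel describes each feature as a 1x1 spatial map, so CUDNN_BATCHNORM_PER_ACTIVATION reduces to an elementwise normalization with the moving statistics. A NumPy sketch (the function name is ours, not a MindSpore API); plugging in the initializers from test_1d_eval should reproduce y1 and y2 to the test's 1e-4 tolerance:

import numpy as np

def batch_norm_1d_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-5):
    # y = gamma * (x - mean) / sqrt(var + eps) + beta, broadcast over the
    # batch axis; eps corresponds to the kernel's epsilon_ attribute.
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta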