| @@ -38,14 +38,14 @@ class BatchNormFold2GpuKernel : public GpuKernel { | |||
| ~BatchNormFold2GpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| @@ -66,7 +66,7 @@ class BatchNormFold2GpuKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -98,9 +98,9 @@ class BatchNormFold2GpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = channel_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); | |||
| @@ -38,14 +38,14 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { | |||
| ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| @@ -88,7 +88,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -120,9 +120,9 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = channel_ * sizeof(T); | |||
| size_t workspace_size = batch_size_ * channel_ * sizeof(T); | |||
| @@ -46,14 +46,14 @@ class BatchNormFoldGpuKernel : public GpuKernel { | |||
| ~BatchNormFoldGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| (void)workspace; | |||
| auto x = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto mean = reinterpret_cast<T *>(inputs[1]->addr); | |||
| @@ -104,7 +104,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| @@ -152,7 +152,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| // x, mean, variance, current_step | |||
| input_size_list_.push_back(input_size_); | |||
| input_size_list_.push_back(output_size_); | |||
| @@ -169,7 +169,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { | |||
| workspace_size_list_.push_back(input_size_); | |||
| } | |||
| void InitResource() { | |||
| void InitResource() override { | |||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); | |||
| @@ -42,11 +42,12 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { | |||
| width_(0) {} | |||
| ~BatchNormFoldGradGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| (void)workspace; | |||
| // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' | |||
| T *d_batch_mean = GetDeviceAddress<T>(inputs, 0); | |||
| @@ -92,7 +93,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 6) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input."; | |||
| @@ -128,7 +130,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' | |||
| input_size_list_.push_back(channel_size_); | |||
| input_size_list_.push_back(channel_size_); | |||
| @@ -30,11 +30,11 @@ class CorrectionMulGpuKernel : public GpuKernel { | |||
| CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} | |||
| ~CorrectionMulGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| auto *weight = GetDeviceAddress<T>(inputs, 0); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 1); | |||
| auto *running_std = GetDeviceAddress<T>(inputs, 2); | |||
| @@ -44,7 +44,7 @@ class CorrectionMulGpuKernel : public GpuKernel { | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -69,7 +69,7 @@ class CorrectionMulGpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = batch_size_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); // weight | |||
| @@ -79,7 +79,7 @@ class CorrectionMulGpuKernel : public GpuKernel { | |||
| output_size_list_.push_back(input_size); | |||
| workspace_size_list_.push_back(workspace_size); | |||
| } | |||
| void InitResource() {} | |||
| void InitResource() override {} | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| @@ -30,11 +30,12 @@ class CorrectionMulGradGpuKernel : public GpuKernel { | |||
| CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} | |||
| ~CorrectionMulGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| auto *d_out = GetDeviceAddress<T>(inputs, 0); | |||
| auto *weight = GetDeviceAddress<T>(inputs, 1); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 2); | |||
| @@ -49,7 +50,8 @@ class CorrectionMulGradGpuKernel : public GpuKernel { | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -74,7 +76,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel { | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| void InitSizeLists() override { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = batch_size_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); // d_out | |||
| @@ -85,7 +87,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel { | |||
| output_size_list_.push_back(weight_size); // d_gamma | |||
| workspace_size_list_.push_back(input_size); // tmp d_out * weight | |||
| } | |||
| void InitResource() {} | |||
| void InitResource() override {} | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| @@ -369,7 +369,7 @@ class HSigmoid(Cell): | |||
| Hard sigmoid is defined as: | |||
| .. math:: | |||
| \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), | |||
| \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), | |||
| where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. | |||
| @@ -319,7 +319,7 @@ class HSigmoid(PrimitiveWithInfer): | |||
| Hard sigmoid is defined as: | |||
| .. math:: | |||
| \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), | |||
| \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), | |||
| where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import ms_function | |||
| import mindspore.context as context | |||
| context.set_context(device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.op = P.BatchNormFold2(100000) | |||
| @ms_function | |||
| def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step): | |||
| return self.op(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step) | |||
| class Net_gnd(nn.Cell): | |||
| def __init__(self): | |||
| super(Net_gnd, self).__init__() | |||
| self.conv_mul = P.ConvMul(freeze_bn=100000) | |||
| self.correct_add = P.CorrectionAdd(freeze_bn=100000) | |||
| self.add_fold = P.AddFold() | |||
| @ms_function | |||
| def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step): | |||
| out = self.conv_mul(x, batch_std, running_std, current_step) | |||
| out = self.correct_add(out, gamma, batch_std, batch_mean, | |||
| running_std, running_mean, current_step) | |||
| out = self.add_fold(out, beta, gamma, batch_std, batch_mean) | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnrom_fold2(): | |||
| net = Net() | |||
| c = 64 | |||
| freeze_bn = 100000 | |||
| x = np.random.uniform(-1, 1, size=[3, c, 32, 32]).astype('float32') | |||
| beta = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| gamma = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| batch_std = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| batch_mean = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| running_std = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| running_mean = np.random.uniform(1, 2, size=[c]).astype('float32') | |||
| current_step = np.array([0]).astype('int32') | |||
| output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean), | |||
| Tensor(running_std), Tensor(running_mean), Tensor(current_step)) | |||
| expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, | |||
| 1) if current_step >= freeze_bn else | |||
| x * (running_std / batch_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, | |||
| 1)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(diff > error * -1) | |||
| current_step = np.array([100000]).astype('int32') | |||
| output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean), Tensor(running_std), | |||
| Tensor(running_mean), Tensor(current_step)) | |||
| expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, | |||
| 1) if current_step >= freeze_bn else | |||
| x * (batch_std / running_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, | |||
| 1)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-6 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(diff > error * -1) | |||
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import ms_function | |||
| import mindspore.context as context | |||
| context.set_context(device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.op = P.BatchNormFoldGrad(freeze_bn=10) | |||
| @ms_function | |||
| def construct(self, d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step): | |||
| dx = self.op(d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step) | |||
| return dx | |||
| def np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std): | |||
| n = x.shape[0] * x.shape[2] * x.shape[3] | |||
| dx = d_batch_mean.reshape(1, -1, 1, 1) / n + d_batch_std.reshape(1, -1, 1, 1) * ( | |||
| x - batch_mean.reshape(1, -1, 1, 1)) / batch_std.reshape(1, -1, 1, 1) / n | |||
| return dx | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold_grad1(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') | |||
| d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([0]).astype('int32') | |||
| dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), | |||
| Tensor(current_step)) | |||
| expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std) | |||
| assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold_grad2(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[1, c, 256, 256]).astype('float32') | |||
| d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([0]).astype('int32') | |||
| dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), | |||
| Tensor(current_step)) | |||
| expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std) | |||
| assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold_grad_freeze(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') | |||
| d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([10]).astype('int32') | |||
| dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), | |||
| Tensor(current_step)) | |||
| expect = np.zeros_like(x) | |||
| assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) | |||
| @@ -0,0 +1,116 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import ms_function | |||
| import mindspore.context as context | |||
| context.set_context(device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.op = P.BatchNormFold(freeze_bn=10) | |||
| @ms_function | |||
| def construct(self, x, mean, variance, current_step): | |||
| a, b, c, d = self.op(x, mean, variance, current_step) | |||
| return a, b, c, d | |||
| def np_result(x, mean, var, momentum, epsilon): | |||
| np_mean = x.mean(axis=(0, 2, 3)) | |||
| np_var = x.var(axis=(0, 2, 3)) | |||
| n = x.shape[0] * x.shape[2] * x.shape[3] | |||
| mean_update = momentum * np_mean + (1 - momentum) * mean | |||
| var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var | |||
| np_var = np.sqrt(np_var + epsilon) | |||
| delay_mean = mean.copy() | |||
| delay_std = np.sqrt(var + epsilon) | |||
| return np_mean, np_var, mean_update, var_update, delay_mean, delay_std | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') | |||
| mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| variance = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([0]).astype('int32') | |||
| ms_mean = Tensor(mean) | |||
| ms_var = Tensor(variance) | |||
| batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var, | |||
| Tensor(current_step)) | |||
| expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) | |||
| assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(ms_var.asnumpy(), expect4, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold2(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[3, c, 512, 512]).astype('float32') | |||
| mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| variance = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([0]).astype('int32') | |||
| ms_mean = Tensor(mean) | |||
| ms_var = Tensor(variance) | |||
| batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var, | |||
| Tensor(current_step)) | |||
| expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) | |||
| assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_batchnorm_fold_freeze(): | |||
| net = Net() | |||
| c = 64 | |||
| x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') | |||
| mean = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| variance = np.random.uniform(1, 10, size=[c]).astype('float32') | |||
| current_step = np.array([10]).astype('int32') | |||
| ms_mean = Tensor(mean) | |||
| ms_var = Tensor(variance) | |||
| batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var, | |||
| Tensor(current_step)) | |||
| expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) | |||
| assert np.allclose(batch_mean.asnumpy(), np.zeros_like(mean), rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(batch_var.asnumpy(), np.ones_like(mean), rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(ms_mean.asnumpy(), mean, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(ms_var.asnumpy(), variance, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) | |||
| assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) | |||
| @@ -14,10 +14,10 @@ | |||
| # ============================================================================ | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| @@ -0,0 +1,55 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import os | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import ms_function | |||
| import mindspore.context as context | |||
| context.set_context(device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.op_w = P.CorrectionMulGrad() | |||
| @ms_function | |||
| def construct(self, dy, x, batch_std, running_std): | |||
| dx, d_batch_std = self.op_w(dy, x, batch_std, running_std) | |||
| return dx, d_batch_std | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_correction_mul_grad(): | |||
| net = Net() | |||
| co, ci, h, w = 64, 1, 32, 32 | |||
| dout = np.random.uniform(-0.1, 0.1, size=[co, ci, h, w]).astype('float32') | |||
| x = np.random.uniform(1, 1, size=[co, ci, h, w]).astype('float32') | |||
| batch_std = np.random.uniform(1, 10, size=[co]).astype('float32') | |||
| running_std = np.random.uniform(1, 10, size=[co]).astype('float32') | |||
| output = net(Tensor(dout), Tensor(x), Tensor(batch_std), Tensor(running_std)) | |||
| expect = [0, 0] | |||
| expect[0] = (dout * np.reshape(batch_std / running_std, (co, 1, 1, 1))) | |||
| expect[1] = (np.sum(dout * x, (1, 2, 3)) / running_std) | |||
| for i, v in enumerate(output): | |||
| assert (np.allclose(output[i].asnumpy(), expect[i], rtol=1.e-5, atol=1.e-5)) | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import ms_function | |||
| import mindspore.context as context | |||
| context.set_context(device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.op = P.CorrectionMul() | |||
| @ms_function | |||
| def construct(self, x, batch_var, moving_var): | |||
| return self.op(x, batch_var, moving_var) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_correction_mul(): | |||
| net = Net() | |||
| co = 64 | |||
| x = np.random.uniform(-1, 1, size=[co, 64, 32, 32]).astype('float32') | |||
| bv = np.random.uniform(1, 2, size=[co]).astype('float32') | |||
| mv = np.random.uniform(1, 2, size=[co]).astype('float32') | |||
| output = net(Tensor(x), Tensor(bv), Tensor(mv)) | |||
| expect = x * np.reshape(bv, (co, 1, 1, 1)) / np.reshape(mv, (co, 1, 1, 1)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert np.all(diff > error * -1) | |||
| assert (output.shape() == expect.shape) | |||