From: @jonwe Reviewed-by: @liangchenghui Signed-off-by:tags/v1.1.0
| @@ -0,0 +1,120 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include <thrust/device_ptr.h> | |||
| #include <thrust/fill.h> | |||
| #include <thrust/reduce.h> | |||
| #include <thrust/system/cuda/execution_policy.h> | |||
| #include "batchnorm_grad_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
// Number of lanes the warp-shuffle reductions in this file assume per warp.
constexpr int kWarpSize = 32;
// Threads per block used when launching the batch-norm gradient kernel.
constexpr int kBlockSize = 1024;
// Warps per block; also the number of per-warp partial sums staged in shared memory.
// Derived from the two constants above so that it cannot drift from the launch configuration.
constexpr int kNumWarps = kBlockSize / kWarpSize;
static_assert(kBlockSize % kWarpSize == 0, "kBlockSize must be a whole number of warps");
static_assert(kNumWarps <= kWarpSize, "per-warp partial sums must be reducible by a single warp");
/**
 * Batch-norm backward pass that uses the supplied mean/variance as constants.
 *
 * Launch layout: one block per channel (blockIdx.x selects the channel). blockDim.x must be a multiple of
 * kWarpSize and at most kBlockSize, because the reductions use full-warp shuffles and stage one partial sum
 * per warp in kNumWarps shared-memory slots.
 *
 * The element offset is computed as n * C * H * W + c * H * W + hw, i.e. the tensors are addressed as
 * contiguous NCHW.
 *
 * Per channel c the kernel produces:
 *   dx[n, c, h, w] = dy[n, c, h, w] * scale[c] / sqrt(save_variance[c] + epsilon)
 *   bn_bias[c]     = sum over (n, h, w) of dy
 *   bn_scale[c]    = sum over (n, h, w) of (x - save_mean[c]) * dy, divided by sqrt(save_variance[c] + epsilon)
 *
 * All arithmetic and both reductions are carried out in float regardless of T, so the half instantiation
 * does not accumulate in half precision; only dx is rounded back to T.
 *
 * \param x_input        forward input, N * C * H * W elements
 * \param dy             incoming gradient, same extent as x_input
 * \param scale          per-channel scale, or nullptr to use 1
 * \param save_mean      per-channel mean
 * \param save_variance  per-channel value that is treated as a variance (epsilon is added, then sqrt is taken)
 * \param dx             output gradient with respect to the input, same extent as x_input
 * \param bn_scale       per-channel output gradient with respect to scale
 * \param bn_bias        per-channel output gradient with respect to bias
 * \param epsilon        value added to the variance before the square root
 * \param N, C, H, W     tensor extents
 */
template <typename T>
__global__ void BatchNormGradKernel(T *x_input, T *dy, float *scale, float *save_mean, float *save_variance, T *dx,
                                    float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W) {
  __shared__ float shared_dy[kNumWarps];
  __shared__ float shared_p[kNumWarps];

  const int warpId = threadIdx.x / kWarpSize;
  const int laneId = threadIdx.x % kWarpSize;
  const int plane = blockIdx.x;

  // 64-bit extents: N * C * H * W can exceed the range of a 32-bit int.
  const int64_t spatial_size = static_cast<int64_t>(H) * W;
  const int64_t plane_size = static_cast<int64_t>(N) * spatial_size;
  const int64_t batch_stride = static_cast<int64_t>(C) * spatial_size;
  const int64_t plane_offset = static_cast<int64_t>(plane) * spatial_size;

  // The square root is evaluated in double (epsilon is a double) and then narrowed to float.
  const float invstd = 1.0f / static_cast<float>(sqrt(static_cast<double>(save_variance[plane]) + epsilon));
  const float scale_val = (scale != nullptr) ? scale[plane] : 1.0f;
  const float grad_scale = invstd * scale_val;
  const float mean = save_mean[plane];

  float dy_sum = 0.0f;
  float dot_p = 0.0f;

  // Clear every staging slot so that a launch with fewer than kNumWarps warps still reduces zeros
  // for the slots no warp writes.
  if (threadIdx.x < kNumWarps) {
    shared_dy[threadIdx.x] = 0.0f;
    shared_p[threadIdx.x] = 0.0f;
  }
  // Orders the zero-fill above before the per-warp partial sums are stored below.
  __syncthreads();

  // Compute three values across (Batch, Height, Width) in one pass:
  // 1. dx
  // 2. Sum(dy)
  // 3. DotProduct(x - mean, dy)
  for (int64_t i = threadIdx.x; i < plane_size; i += blockDim.x) {
    const int64_t index = (i / spatial_size) * batch_stride + plane_offset + (i % spatial_size);
    // Load both operands before storing dx so the sums stay correct even if dx aliases dy.
    const float dy_val = static_cast<float>(dy[index]);
    const float x_val = static_cast<float>(x_input[index]);
    dx[index] = static_cast<T>(dy_val * grad_scale);
    dy_sum += dy_val;
    dot_p += (x_val - mean) * dy_val;
  }

  // Warp reduction: after the loop lane 0 of each warp holds that warp's totals.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    dy_sum += __shfl_down_sync(0xffffffff, dy_sum, offset);
    dot_p += __shfl_down_sync(0xffffffff, dot_p, offset);
  }

  // Move warp-reduction result to shared memory
  if (laneId == 0) {
    shared_dy[warpId] = dy_sum;
    shared_p[warpId] = dot_p;
  }
  __syncthreads();

  // Shared memory reduction
  // There are exactly kNumWarps (<= kWarpSize) items in shared memory, so one warp can reduce them.
  if (warpId == 0) {
    dy_sum = shared_dy[laneId];
    dot_p = shared_p[laneId];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      dy_sum += __shfl_down_sync(0xffffffff, dy_sum, offset);
      dot_p += __shfl_down_sync(0xffffffff, dot_p, offset);
    }
  }

  // Thread 0 (lane 0 of warp 0) now holds the block-wide totals; write bn_scale & bn_bias.
  if (threadIdx.x == 0) {
    bn_scale[plane] = dot_p * invstd;
    bn_bias[plane] = dy_sum;
  }
}
/**
 * Enqueues BatchNormGradKernel on `cuda_stream` with one block of kBlockSize threads per channel.
 *
 * The call only enqueues the kernel on the stream; it does not synchronize with it.
 * When C is not positive nothing is launched, because a grid with zero blocks is not a valid launch
 * configuration and there is no channel to produce output for.
 *
 * \param x, dy          input tensor and incoming gradient, addressed by the kernel as N x C x H x W
 * \param scale          per-channel scale (the kernel treats nullptr as 1)
 * \param save_mean      per-channel mean
 * \param save_variance  per-channel variance
 * \param dx             output gradient with respect to x
 * \param bn_scale       per-channel output gradient with respect to scale
 * \param bn_bias        per-channel output gradient with respect to bias
 * \param epsilon        value added to the variance before the square root
 * \param N, C, H, W     tensor extents
 * \param cuda_stream    stream the kernel is enqueued on
 */
template <typename T>
void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
                      float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream) {
  if (C <= 0) {
    return;
  }
  BatchNormGradKernel<<<C, kBlockSize, 0, cuda_stream>>>(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias,
                                                         epsilon, N, C, H, W);
}

// Explicit instantiations for the element types offered to callers of this translation unit.
template void CalBatchNormGrad<float>(float *x, float *dy, float *scale, float *save_mean, float *save_variance,
                                      float *dx, float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H,
                                      int W, cudaStream_t cuda_stream);
template void CalBatchNormGrad<half>(half *x, half *dy, float *scale, float *save_mean, float *save_variance, half *dx,
                                     float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W,
                                     cudaStream_t cuda_stream);
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
// Host-side entry point for the batch-norm gradient computation; work is issued on `cuda_stream`.
//   x, dy          : input tensor and incoming gradient (element type T)
//   scale          : per-channel scale
//   save_mean      : per-channel mean
//   save_variance  : per-channel variance
//   dx             : output gradient with respect to x (element type T)
//   bn_scale       : output gradient with respect to scale
//   bn_bias        : output gradient with respect to bias
//   epsilon        : numerical-stability term
//   N, C, H, W     : tensor extents
template <typename T>
void CalBatchNormGrad(T *x,
                      T *dy,
                      float *scale,
                      float *save_mean,
                      float *save_variance,
                      T *dx,
                      float *bn_scale,
                      float *bn_bias,
                      double epsilon,
                      int N,
                      int C,
                      int H,
                      int W,
                      cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_ | |||
| @@ -21,6 +21,7 @@ | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -66,16 +67,21 @@ class BatchNormGradGpuKernel : public GpuKernel { | |||
| // For CI only, reserved vars can not be unused. | |||
| MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2); // NOLINT | |||
| const float alpha_data_diff = 1; | |||
| const float beta_data_diff = 0; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, | |||
| &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, | |||
| bn_scale, bn_bias, epsilon_, save_mean, save_variance), | |||
| "Kernel Launch Failed."); | |||
| if (is_training_) { | |||
| const float alpha_data_diff = 1; | |||
| const float beta_data_diff = 0; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, | |||
| &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, | |||
| scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance), | |||
| "Kernel Launch Failed."); | |||
| } else { | |||
| CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_, | |||
| height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| @@ -104,6 +110,7 @@ class BatchNormGradGpuKernel : public GpuKernel { | |||
| width_ = SizeToInt(shape[3]); | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||
| is_training_ = GetAttr<bool>(kernel_node, "is_training"); | |||
| epsilon_ = GetAttr<float>(kernel_node, "epsilon"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| @@ -175,6 +182,7 @@ class BatchNormGradGpuKernel : public GpuKernel { | |||
| int width_; | |||
| cudnnBatchNormMode_t mode_; | |||
| bool is_training_; | |||
| double epsilon_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| @@ -178,3 +178,26 @@ def test_train_stats_false_forward(): | |||
| diff = output.asnumpy() - expect_output | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_infer_backward():
    """Check the gradient returned for a batch-norm net on GPU when the net is not in training mode.

    The net is built with unit weight, zero bias, zero moving mean and unit moving variance, and the
    first element of the gradient output is compared against fixed reference values.
    """
    # Seed before drawing so that both random arrays match the hard-coded reference below.
    np.random.seed(1)
    inputs_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
    sens_np = np.random.randn(1, 3, 2, 2).astype(np.float32)

    inputs = Tensor(inputs_np)
    gamma = Tensor(np.ones(3).astype(np.float32))
    beta = Tensor(np.zeros(3).astype(np.float32))
    running_mean = Tensor(np.zeros(3).astype(np.float32))
    running_var = Tensor(np.ones(3).astype(np.float32))

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = Batchnorm_Net(3, gamma, beta, running_mean, running_var)
    net.set_train(False)
    grad_net = Grad(net)
    grads = grad_net(inputs, Tensor(sens_np))

    expected_grad = np.array(
        [[[[-0.3224156, -0.3840524], [1.1337637, -1.0998858]],
          [[-0.1724273, -0.877854], [0.0422135, 0.5828123]],
          [[-1.1006137, 1.1447179], [0.9015862, 0.5024918]]]]
    ).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expected_grad)