From 8bfeadf26eb10b05df4decff0444f8af2a9443a4 Mon Sep 17 00:00:00 2001 From: hedongdong Date: Thu, 10 Dec 2020 14:10:24 +0800 Subject: [PATCH] Add new operator layer_norm_grad_grad --- .../cuda_impl/layer_norm_grad_grad_impl.cu | 395 ++++++++++++++++++ .../cuda_impl/layer_norm_grad_grad_impl.cuh | 28 ++ .../gpu/nn/layer_norm_grad_grad_gpu_kernel.cc | 50 +++ .../gpu/nn/layer_norm_grad_grad_gpu_kernel.h | 128 ++++++ mindspore/ops/_grad/grad_nn_ops.py | 13 + mindspore/ops/operations/_grad_ops.py | 27 ++ .../ops/gpu/test_layer_norm_grad_grad_op.py | 142 +++++++ 7 files changed, 783 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_layer_norm_grad_grad_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu new file mode 100644 index 0000000000..acd1786fce --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu @@ -0,0 +1,395 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh" + +constexpr int NUM_PER_THREAD_REDUCE = 4; +constexpr int WARP_SIZE = 32; +constexpr int NUM_SHARED_SUM_INPUT = 6; +constexpr int NUM_SHARED_SUM_GAMMA = 3; + +template +inline __device__ T my_pow(T a, double b) { + return pow(a, static_cast(b)); +} + +template <> +inline __device__ half my_pow(half a, double b) { + return __float2half(pow(__half2float(a), static_cast(b))); +} + + +template +inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim, + const int &mean_dim, const T &epsilon, const T *dy, const T *x, + const T *mean, const T *var, const T *grad_dx, T *part1, T *part2, + T *part3, const T *global_sum1, const T *global_sum2) { + int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int row = NUM_PER_THREAD_REDUCE * i + j; + if (row >= row_dim) { + return; + } + + int pos = row * col_dim + col; + int mean_offset = pos / mean_dim; + + T v1 = my_pow(var[mean_offset] + epsilon, -0.5); + + part1[0] += dy[pos] * v1 * (x[pos] - mean[mean_offset]) * global_sum2[pos]; + part2[0] += dy[pos] * global_sum1[pos]; + part3[0] += dy[pos] * grad_dx[pos] * v1; + } + } +} + +template +inline __device__ void GammaAndBetaWarpReduce(T *part1, T *part2, T *part3) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + part1[0] += __shfl_down_sync(0xffffffff, part1[0], delta); + part2[0] += __shfl_down_sync(0xffffffff, part2[0], delta); + part3[0] += __shfl_down_sync(0xffffffff, part3[0], delta); + } +} + +template +inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *part1, T *part2, T *part3, + T *d_gamma) { + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + DynamicSharedMem share_mem; + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 3; + share_mem.addr()[offset] = part1[0]; + share_mem.addr()[offset + 1] = part2[0]; + share_mem.addr()[offset + 2] = part3[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 3; + share_mem.addr()[threadIdx.x * 3] += share_mem.addr()[offset]; + share_mem.addr()[threadIdx.x * 3 + 1] += share_mem.addr()[offset + 1]; + share_mem.addr()[threadIdx.x * 3 + 2] += share_mem.addr()[offset + 2]; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + d_gamma[col] = share_mem.addr()[0] + share_mem.addr()[1] + share_mem.addr()[2]; + } +} + +template +__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon, + const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx, + T *d_gamma, T *global_sum1, T *global_sum2) { + for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { + T part1 = 0; + T part2 = 0; + T part3 = 0; + GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean, var, grad_dx, &part1, &part2, + &part3, global_sum1, global_sum2); + GammaAndBetaWarpReduce(&part1, &part2, &part3); + GammaAndBetaBlockReduce(col, row_dim, &part1, &part2, &part3, d_gamma); + } +} + + +template +inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col_dim, const int ¶m_dim, + const T &epsilon, T *sum1, T *sum2, T *sum3, T *sum4, + const T *dy, const T *x, const T *mean, const T *var, + const T *gamma, const T *grad_dx) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + + T v1 = x[pos] - mean[row]; + T v2 = my_pow(var[row] + epsilon, -0.5); + T v3 = v1 * v2; + T v4 = dy[pos] * gamma[gamma_offset]; + + sum1[0] -= v2 * grad_dx[pos]; + sum2[0] -= v3 * v2 * grad_dx[pos]; + sum3[0] += v4; + sum4[0] += v4 * v3; + } + } +} + +template +inline __device__ void InputWarpReduceInnerMean(T *sum1, T *sum2, T *sum3, T *sum4) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); + sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); + sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); + sum4[0] += __shfl_down_sync(0xffffffff, sum4[0], delta); + } +} + +template +inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T *sum2, T *sum3, T *sum4, T *share_mem) { + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 6; + share_mem[offset] = sum1[0]; + share_mem[offset + 1] = sum2[0]; + share_mem[offset + 2] = sum3[0]; + share_mem[offset + 3] = sum4[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 6; + + share_mem[threadIdx.x * 3] += share_mem[offset]; + share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; + share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; + share_mem[threadIdx.x * 3 + 3] += share_mem[offset + 3]; + } + } + __syncthreads(); +} + + +template +inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim, + const T &epsilon, T *sum5, T *sum6, T *share_mem, const T *dy, + const T *x, const T *mean, const T *var, const T *gamma, + const T *grad_dx, const T *grad_dg, T *d_x) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + + T v1 = x[pos] - mean[row]; + T v2 = my_pow(var[row] + epsilon, -0.5); + T v3 = dy[pos] * gamma[gamma_offset]; + T v4 = v3 * share_mem[1] * (1.0 / col_dim); + T v5 = grad_dx[pos] * v2 * share_mem[3] * (-1.0 / col_dim); + T v6 = dy[pos] * grad_dg[gamma_offset]; + T v7 = v4 + v5 + v6; + + T part1 = v1 * v7; + T part2 = v2 * v7; + d_x[pos] = part2; + + sum5[0] += part1; + sum6[0] -= part2; + } + } +} + + +template <> +inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim, + const half &epsilon, half *sum5, half *sum6, half *share_mem, + const half *dy, const half *x, const half *mean, const half *var, + const half *gamma, const half *grad_dx, const half *grad_dg, + half *d_x) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + + half v1 = x[pos] - mean[row]; + half v2 = my_pow(var[row] + epsilon, -0.5); + half v3 = dy[pos] * gamma[gamma_offset]; + half v4 = v3 * share_mem[1] * __float2half(1.0 / col_dim); + half v5 = grad_dx[pos] * v2 * share_mem[3] * __float2half(-1.0 / col_dim); + half v6 = dy[pos] * grad_dg[gamma_offset]; + half v7 = v4 + v5 + v6; + + half part1 = v1 * v7; + half part2 = v2 * v7; + d_x[pos] = part2; + + sum5[0] += part1; + sum6[0] -= part2; + } + } +} + + +template +inline __device__ void InputWarpReduceOuterMean(T *sum5, T *sum6) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + sum5[0] += __shfl_down_sync(0xffffffff, sum5[0], delta); + sum6[0] += __shfl_down_sync(0xffffffff, sum6[0], delta); + } +} + +template +inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T *sum6, T *share_mem) { + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 6; + + share_mem[offset + 4] = sum5[0]; + share_mem[offset + 5] = sum6[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 6; + + share_mem[threadIdx.x * 6 + 4] += share_mem[offset + 4]; + share_mem[threadIdx.x * 6 + 5] += share_mem[offset + 5]; + } + } + __syncthreads(); +} + + +template +inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const T &epsilon, + const T *dy, const T *x, const T *mean, const T *var, const T *gamma, + const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, + const T *share_mem, T *global_sum1, T *global_sum2) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + + T v1 = x[pos] - mean[row]; + T v2 = my_pow(var[row] + epsilon, -0.5); + T v3 = v1 * v2; + + T part1 = gamma[gamma_offset] * grad_dx[pos] * v2; + T part2 = gamma[gamma_offset] * share_mem[0] * (1.0 / col_dim); + T part3 = gamma[gamma_offset] * v3 * share_mem[1] * (1.0 / col_dim); + T part4 = v3 * grad_dg[gamma_offset]; + d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset]; + + T part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * (-1.0 / col_dim))); + d_x[pos] += part5 + share_mem[5] * (1.0 / col_dim); + + global_sum1[pos] = share_mem[0] * (1.0 / col_dim); + global_sum2[pos] = share_mem[1] * (1.0 / col_dim); + } +} + + +template <> +inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const half &epsilon, + const half *dy, const half *x, const half *mean, const half *var, const half *gamma, + const half *grad_dx, const half *grad_dg, const half *grad_db, half *d_dy, half *d_x, + const half *share_mem, half *global_sum1, half *global_sum2) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + + half v1 = x[pos] - mean[row]; + half v2 = my_pow(var[row] + epsilon, -0.5); + half v3 = v1 * v2; + + half part1 = gamma[gamma_offset] * grad_dx[pos] * v2; + half part2 = gamma[gamma_offset] * share_mem[0] * __float2half(1.0 / col_dim); + half part3 = gamma[gamma_offset] * v3 * share_mem[1] * __float2half(1.0 / col_dim); + half part4 = v3 * grad_dg[gamma_offset]; + d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset]; + + half part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * __float2half(-1.0 / col_dim))); + d_x[pos] += part5 + share_mem[5] * __float2half(1.0 / col_dim); + + global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim); + global_sum2[pos] = share_mem[1] * __float2half(1.0 / col_dim); + } +} + + +template +__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, + const T *dy, const T *x, const T *mean, const T *var, const T *gamma, + const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *global_sum1, + T *global_sum2) { + for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { + T sum1 = 0; + T sum2 = 0; + T sum3 = 0; + T sum4 = 0; + T sum5 = 0; + T sum6 = 0; + DynamicSharedMem share_mem; + + InputThreadReduceInnerMean(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, &sum4, dy, x, mean, var, gamma, + grad_dx); + InputWarpReduceInnerMean(&sum1, &sum2, &sum3, &sum4); + InputBlockReduceInnerMean(col_dim, &sum1, &sum2, &sum3, &sum4, share_mem.addr()); + + InputThreadReduceOuterMean(row, col_dim, param_dim, epsilon, &sum5, &sum6, share_mem.addr(), dy, x, mean, + var, gamma, grad_dx, grad_dg, d_x); + InputWarpReduceOuterMean(&sum5, &sum6); + InputBlockReduceOuterMean(col_dim, &sum5, &sum6, share_mem.addr()); + InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, grad_dx, grad_dg, grad_db, d_dy, d_x, + share_mem.addr(), global_sum1, global_sum2); + } +} + +template +void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, T *global_sum1, T *global_sum2, + const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma, + const T* grad_dx, const T* grad_dg, const T* grad_db, T *d_dy, T *d_x, T *d_gamma, + cudaStream_t stream) { + const int thread_per_block = 256; + + int share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T); + InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, + mean, var, gamma, grad_dx, grad_dg, grad_db, + d_dy, d_x, global_sum1, global_sum2); + share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_GAMMA * sizeof(T); + int param_reduce_dim = row_dim * col_dim / param_dim; + GammaAndBetaPropKernel<<>>(param_reduce_dim, param_dim, + col_dim, epsilon, dy, x, mean, var, + grad_dx, d_gamma, global_sum1, + global_sum2); +} + +template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, float *global_sum1, + float *global_sum2, const float &epsilon, const float *dy, const float *x, + const float *mean, const float *var, const float *gamma, const float *grad_dx, + const float *grad_dg, const float *grad_db, float *d_dy, float *d_x, float *d_gamma, + cudaStream_t stream); +template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, half *global_sum1, + half *global_sum2, const half &epsilon, const half *dy, const half *x, const half *mean, + const half *var, const half *gamma, const half *grad_dx, const half *grad_dg, + const half *grad_db, half *d_dy, half *d_x, half *d_gamma, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh new file mode 100644 index 0000000000..7ffc93ca36 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void LayerNormGradGrad(const int& row_dim, const int& col_dim, const int& param_dim, T* global_sum1, T* global_sum2, + const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, const T* gamma, + const T* grad_dx, const T* grad_dg, const T* grad_db, T* d_dy, T* d_x, T* d_gamma, + cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc new file mode 100644 index 0000000000..96192dbd48 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LayerNormGradGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LayerNormGradGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h new file mode 100644 index 0000000000..c6752df9cf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class LayerNormGradGradGpuKernel : public GpuKernel { + public: + LayerNormGradGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1), input_size_(1) {} + ~LayerNormGradGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto x = GetDeviceAddress(inputs, 0); + auto dy = GetDeviceAddress(inputs, 1); + auto var = GetDeviceAddress(inputs, 2); + auto mean = GetDeviceAddress(inputs, 3); + auto gamma = GetDeviceAddress(inputs, 4); + auto grad_dx = GetDeviceAddress(inputs, 5); + auto grad_dg = GetDeviceAddress(inputs, 6); + auto grad_db = GetDeviceAddress(inputs, 7); + auto d_x = GetDeviceAddress(outputs, 0); + auto d_dy = GetDeviceAddress(outputs, 1); + auto d_gamma = GetDeviceAddress(outputs, 2); + + auto global_sum1 = GetDeviceAddress(workspace, 0); + auto global_sum2 = GetDeviceAddress(workspace, 1); + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemsetAsync(global_sum1, 0, input_size_, reinterpret_cast(stream_ptr)), + "cudaMemsetAsync global_sum1 failed"); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemsetAsync(global_sum2, 0, input_size_, reinterpret_cast(stream_ptr)), + "cudaMemsetAsync global_sum2 failed"); + + const T epsilon = 10e-12; + LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon, dy, x, mean, var, gamma, + grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma, reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + int begin_norm_axis = static_cast(GetAttr(kernel_node, "begin_norm_axis")); + int begin_params_axis = static_cast(GetAttr(kernel_node, "begin_params_axis")); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (begin_norm_axis < 0) { + begin_norm_axis += input_shape.size(); + } + + if (begin_params_axis < 0) { + begin_params_axis += input_shape.size(); + } + + for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { + input_row_ *= input_shape[i]; + } + + for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { + input_col_ *= input_shape[i]; + } + + for (size_t i = begin_params_axis; i < input_shape.size(); i++) { + param_dim_ *= input_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_ = input_row_ * input_col_ * sizeof(T); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(param_dim_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + + output_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + output_size_list_.push_back(param_dim_ * sizeof(T)); + + workspace_size_list_.push_back(input_size_); + workspace_size_list_.push_back(input_size_); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int input_row_; + int input_col_; + int param_dim_; + int input_size_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index f42065d813..5f7d03ced0 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -673,6 +673,19 @@ def get_bprop_layer_norm(self): return bprop +@bprop_getters.register(G.LayerNormGrad) +def get_bprop_layer_norm_grad(self): + """Grad definition for `LayerNorm` operation.""" + layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis) + + def bprop(x, dy, variance, mean, gamma, out, dout): + d_x, d_dy, d_gamma = layer_norm_grad_grad( + x, dy, variance, mean, gamma, dout[0], dout[1], dout[2]) + return d_x, d_dy, d_gamma, zeros_like(variance), zeros_like(mean) + + return bprop + + @bprop_getters.register(P.L2Normalize) def get_bprop_l2normalize(self): """Grad definition for `L2Normalize` operation.""" diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 423d3f12d3..aa961eeb66 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1087,6 +1087,33 @@ class LayerNormGrad(Primitive): raise NotImplementedError +class LayerNormGradGrad(PrimitiveWithInfer): + """ + Gets the gradient of LayerNormGrad operation. + + Args: + begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1. + begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1. + + Returns: + tuple[int], tuple of 3 values (the gradients of layernormgrad input, dy, gamma). + """ + + @prim_attr_register + def __init__(self, begin_norm_axis=1, begin_params_axis=1): + """init""" + self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) + + def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): + raise NotImplementedError + def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): + return x, dy, gamma + + def infer_dtype(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): + return x, dy, gamma + + class LogSoftmaxGrad(PrimitiveWithInfer): """Computes gradient for the Log Softmax activation.""" diff --git a/tests/st/ops/gpu/test_layer_norm_grad_grad_op.py b/tests/st/ops/gpu/test_layer_norm_grad_grad_op.py new file mode 100644 index 0000000000..6330c060e1 --- /dev/null +++ b/tests/st/ops/gpu/test_layer_norm_grad_grad_op.py @@ -0,0 +1,142 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class LayerNormGradGradNet(nn.Cell): + def __init__(self, begin_norm_axis, begin_params_axis): + super(LayerNormGradGradNet, self).__init__() + self.norm = G.LayerNormGradGrad(begin_norm_axis, begin_params_axis) + + def construct(self, x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db): + return self.norm(x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db) + + +def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis): + begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape) + begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape) + + norm_axis = [i for i in range(begin_norm_axis, len(x.shape))] + param_axis = [i for i in range(0, begin_params_axis)] + num = 1 + for i in range(begin_norm_axis, len(x.shape)): + num *= x.shape[i] + + mean = np.mean(x, axis=tuple(norm_axis), keepdims=True) + var = np.var(x, axis=tuple(norm_axis), keepdims=True) + + gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:])) + dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True) + db = np.sum(dy, axis=tuple(param_axis), keepdims=True) + + sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis), + keepdims=True) + sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True) + sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True) + + dx1 = dy * gamma * np.power(var + epsilon, -0.5) + dx2 = sum1 * 2.0 / num * (x - mean) + dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num) + dx = dx1 + dx2 + dx3 + return dx, dg, db, mean, var + +def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, grad_db_np, begin_norm_axis, + begin_params_axis): + begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape) + begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape) + + norm_axis = tuple([i for i in range(begin_norm_axis, len(x.shape))]) + param_axis = [i for i in range(0, begin_params_axis)] + num = 1 + for i in range(begin_norm_axis, len(x.shape)): + num *= x.shape[i] + + mean = np.mean(x, tuple(norm_axis), keepdims=True) + var = np.mean(np.power((x - mean), 2), tuple(norm_axis), keepdims=True) + inv_std = np.power(var + epsilon, -0.5) + x_hat = (x - mean) * inv_std + + sum1 = np.mean((-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True) + sum2 = np.mean(x_hat * (-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True) + sum3 = np.mean(dy * gamma * x_hat, tuple(norm_axis), keepdims=True) + part = dy * gamma * sum2 + sum3 * (-1.0) * grad_dx_np * inv_std + dy * grad_dg_np + sum4 = np.mean((x - mean) * part, tuple(norm_axis), keepdims=True) + sum5 = np.mean(-inv_std * part, tuple(norm_axis), keepdims=True) + + part1 = inv_std * part + part2 = (x - mean) * (-1.0) * np.power(var + epsilon, -1.5) * sum4 + d_x = part1 + part2 + sum5 + + part3 = gamma * grad_dx_np * inv_std + part4 = gamma * sum1 + part5 = gamma * x_hat * sum2 + part6 = x_hat * grad_dg_np + d_dy = part3 + part4 + part5 + part6 + grad_db_np + + part7 = np.sum(dy * x_hat * sum2, tuple(param_axis), keepdims=True) + part8 = np.sum(dy * sum1, tuple(param_axis), keepdims=True) + part9 = np.sum(dy * grad_dx_np * inv_std, tuple(param_axis), keepdims=True) + d_gamma = part7 + part8 + part9 + + return d_x, d_dy, d_gamma + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernormgradgrad0(): + np.random.seed(1) + begin_norm_axis = 1 + begin_params_axis = 1 + + x_np = np.random.rand(2, 2, 4).astype(np.float32) + dy_np = np.random.rand(2, 2, 4).astype(np.float32) + gamma_np = np.random.rand(2, 4).astype(np.float32) + + grad_dx_np = np.random.rand(2, 2, 4).astype(np.float32) + grad_dg_np = np.random.rand(2, 4).astype(np.float32) + grad_db_np = np.random.rand(2, 4).astype(np.float32) + + epsilon = 1e-12 + _, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, + begin_params_axis) + + d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np, + grad_db_np, begin_norm_axis, begin_params_axis) + + dy_ms = Tensor(dy_np) + x_ms = Tensor(x_np) + var_ms = Tensor(var_np) + mean_ms = Tensor(mean_np) + gamma_ms = Tensor(gamma_np) + grad_dx_ms = Tensor(grad_dx_np) + grad_dg_ms = Tensor(grad_dg_np) + grad_db_ms = Tensor(grad_db_np) + + net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis) + d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms) + + assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=1e-6, atol=1e-3) + assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=1e-6, atol=1e-6) + assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=1e-6, atol=1e-3)