@@ -0,0 +1,395 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh"

constexpr int NUM_PER_THREAD_REDUCE = 4;
constexpr int WARP_SIZE = 32;
constexpr int NUM_SHARED_SUM_INPUT = 6;
constexpr int NUM_SHARED_SUM_GAMMA = 3;
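
// Each thread folds NUM_PER_THREAD_REDUCE consecutive elements per
// grid-stride step. Per warp, the input-gradient kernel keeps
// NUM_SHARED_SUM_INPUT (6) partial sums in shared memory; the gamma kernel
// keeps NUM_SHARED_SUM_GAMMA (3).
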
template <typename T>
inline __device__ T my_pow(T a, double b) {
  return pow(a, static_cast<float>(b));
}

template <>
inline __device__ half my_pow(half a, double b) {
  return __float2half(pow(__half2float(a), static_cast<float>(b)));
}

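// Stage 1 of the d_gamma reduction: each thread walks rows of column `col`,
// accumulating three per-column partial sums. global_sum1/global_sum2 hold
// the per-row inner means already computed by InputPropKernel.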
template <typename T>
inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
                                                const int &mean_dim, const T &epsilon, const T *dy, const T *x,
                                                const T *mean, const T *var, const T *grad_dx, T *part1, T *part2,
                                                T *part3, const T *global_sum1, const T *global_sum2) {
  int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
      int row = NUM_PER_THREAD_REDUCE * i + j;
      if (row >= row_dim) {
        return;
      }

      int pos = row * col_dim + col;
      int mean_offset = pos / mean_dim;
      T v1 = my_pow(var[mean_offset] + epsilon, -0.5);
      part1[0] += dy[pos] * v1 * (x[pos] - mean[mean_offset]) * global_sum2[pos];
      part2[0] += dy[pos] * global_sum1[pos];
      part3[0] += dy[pos] * grad_dx[pos] * v1;
    }
  }
}

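// Stage 2: warp-level tree reduction via __shfl_down_sync; after
// log2(WARP_SIZE) steps, lane 0 of each warp holds the warp-wide partials.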
template <typename T>
inline __device__ void GammaAndBetaWarpReduce(T *part1, T *part2, T *part3) {
  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
    part1[0] += __shfl_down_sync(0xffffffff, part1[0], delta);
    part2[0] += __shfl_down_sync(0xffffffff, part2[0], delta);
    part3[0] += __shfl_down_sync(0xffffffff, part3[0], delta);
  }
}

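// Stage 3: warp leaders stage their partials in dynamic shared memory
// (3 slots per warp); a final tree reduction leaves the block-wide sums in
// slots 0..2 and thread 0 writes their total to d_gamma[col].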
template <typename T>
inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *part1, T *part2, T *part3,
                                               T *d_gamma) {
  // Load partials into shared memory; only warp leaders
  // (threads 0, 32, 64, 96, ...) keep data.
  DynamicSharedMem<T> share_mem;
  if (threadIdx.x % WARP_SIZE == 0) {
    int offset = threadIdx.x / WARP_SIZE * 3;
    share_mem.addr()[offset] = part1[0];
    share_mem.addr()[offset + 1] = part2[0];
    share_mem.addr()[offset + 2] = part3[0];
  }
  __syncthreads();

  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      int offset = (threadIdx.x + stride) * 3;
      share_mem.addr()[threadIdx.x * 3] += share_mem.addr()[offset];
      share_mem.addr()[threadIdx.x * 3 + 1] += share_mem.addr()[offset + 1];
      share_mem.addr()[threadIdx.x * 3 + 2] += share_mem.addr()[offset + 2];
    }
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    d_gamma[col] = share_mem.addr()[0] + share_mem.addr()[1] + share_mem.addr()[2];
  }
}

template <typename T>
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
                                       const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx,
                                       T *d_gamma, T *global_sum1, T *global_sum2) {
  for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
    T part1 = 0;
    T part2 = 0;
    T part3 = 0;
    GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean, var, grad_dx, &part1, &part2,
                             &part3, global_sum1, global_sum2);
    GammaAndBetaWarpReduce(&part1, &part2, &part3);
    GammaAndBetaBlockReduce(col, row_dim, &part1, &part2, &part3, d_gamma);
  }
}

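// First pass for d_x/d_dy: per row, accumulate four inner sums
//   sum1 = -sum(rstd * grad_dx)        sum2 = -sum(x_hat * rstd * grad_dx)
//   sum3 =  sum(dy * gamma)            sum4 =  sum(dy * gamma * x_hat)
// with rstd = (var + epsilon)^(-0.5) and x_hat = (x - mean) * rstd.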
template <typename T>
inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col_dim, const int &param_dim,
                                                  const T &epsilon, T *sum1, T *sum2, T *sum3, T *sum4,
                                                  const T *dy, const T *x, const T *mean, const T *var,
                                                  const T *gamma, const T *grad_dx) {
  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
      int col = NUM_PER_THREAD_REDUCE * i + j;
      if (col >= col_dim) {
        return;
      }

      int pos = row * col_dim + col;
      int gamma_offset = pos % param_dim;
      T v1 = x[pos] - mean[row];
      T v2 = my_pow(var[row] + epsilon, -0.5);
      T v3 = v1 * v2;
      T v4 = dy[pos] * gamma[gamma_offset];
      sum1[0] -= v2 * grad_dx[pos];
      sum2[0] -= v3 * v2 * grad_dx[pos];
      sum3[0] += v4;
      sum4[0] += v4 * v3;
    }
  }
}

template <typename T>
inline __device__ void InputWarpReduceInnerMean(T *sum1, T *sum2, T *sum3, T *sum4) {
  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
    sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
    sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
    sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta);
    sum4[0] += __shfl_down_sync(0xffffffff, sum4[0], delta);
  }
}

template <typename T>
inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T *sum2, T *sum3, T *sum4, T *share_mem) {
  // Load partials into shared memory; only warp leaders
  // (threads 0, 32, 64, 96, ...) keep data, in the first 4 of 6 slots per warp.
  if (threadIdx.x % WARP_SIZE == 0) {
    int offset = threadIdx.x / WARP_SIZE * 6;
    share_mem[offset] = sum1[0];
    share_mem[offset + 1] = sum2[0];
    share_mem[offset + 2] = sum3[0];
    share_mem[offset + 3] = sum4[0];
  }
  __syncthreads();

  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      int offset = (threadIdx.x + stride) * 6;
      // Each per-warp record is 6 elements wide, so the accumulating side must
      // also step by 6 (stepping by 3 would mix adjacent records).
      share_mem[threadIdx.x * 6] += share_mem[offset];
      share_mem[threadIdx.x * 6 + 1] += share_mem[offset + 1];
      share_mem[threadIdx.x * 6 + 2] += share_mem[offset + 2];
      share_mem[threadIdx.x * 6 + 3] += share_mem[offset + 3];
    }
  }
  __syncthreads();
}

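// Second pass: with the inner sums staged in share_mem (slot 1 = sum2,
// slot 3 = sum4), form per element
//   part = dy*gamma*sum2/n - grad_dx*rstd*sum4/n + dy*grad_dg     (n = col_dim)
// then accumulate sum5 = sum((x - mean) * part) and sum6 = -sum(rstd * part),
// writing the provisional d_x = rstd * part along the way.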
template <typename T>
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
                                                  const T &epsilon, T *sum5, T *sum6, T *share_mem, const T *dy,
                                                  const T *x, const T *mean, const T *var, const T *gamma,
                                                  const T *grad_dx, const T *grad_dg, T *d_x) {
  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
      int col = NUM_PER_THREAD_REDUCE * i + j;
      if (col >= col_dim) {
        return;
      }

      int pos = row * col_dim + col;
      int gamma_offset = pos % param_dim;
      T v1 = x[pos] - mean[row];
      T v2 = my_pow(var[row] + epsilon, -0.5);
      T v3 = dy[pos] * gamma[gamma_offset];
      T v4 = v3 * share_mem[1] * (1.0 / col_dim);
      T v5 = grad_dx[pos] * v2 * share_mem[3] * (-1.0 / col_dim);
      T v6 = dy[pos] * grad_dg[gamma_offset];
      T v7 = v4 + v5 + v6;
      T part1 = v1 * v7;
      T part2 = v2 * v7;
      d_x[pos] = part2;
      sum5[0] += part1;
      sum6[0] -= part2;
    }
  }
}

template <>
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
                                                  const half &epsilon, half *sum5, half *sum6, half *share_mem,
                                                  const half *dy, const half *x, const half *mean, const half *var,
                                                  const half *gamma, const half *grad_dx, const half *grad_dg,
                                                  half *d_x) {
  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
      int col = NUM_PER_THREAD_REDUCE * i + j;
      if (col >= col_dim) {
        return;
      }

      int pos = row * col_dim + col;
      int gamma_offset = pos % param_dim;
      half v1 = x[pos] - mean[row];
      half v2 = my_pow(var[row] + epsilon, -0.5);
      half v3 = dy[pos] * gamma[gamma_offset];
      half v4 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
      half v5 = grad_dx[pos] * v2 * share_mem[3] * __float2half(-1.0 / col_dim);
      half v6 = dy[pos] * grad_dg[gamma_offset];
      half v7 = v4 + v5 + v6;
      half part1 = v1 * v7;
      half part2 = v2 * v7;
      d_x[pos] = part2;
      sum5[0] += part1;
      sum6[0] -= part2;
    }
  }
}

template <typename T>
inline __device__ void InputWarpReduceOuterMean(T *sum5, T *sum6) {
  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
    sum5[0] += __shfl_down_sync(0xffffffff, sum5[0], delta);
    sum6[0] += __shfl_down_sync(0xffffffff, sum6[0], delta);
  }
}

template <typename T>
inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T *sum6, T *share_mem) {
  // Load partials into shared memory; only warp leaders
  // (threads 0, 32, 64, 96, ...) keep data, in slots 4 and 5 of each record.
  if (threadIdx.x % WARP_SIZE == 0) {
    int offset = threadIdx.x / WARP_SIZE * 6;
    share_mem[offset + 4] = sum5[0];
    share_mem[offset + 5] = sum6[0];
  }
  __syncthreads();

  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      int offset = (threadIdx.x + stride) * 6;
      share_mem[threadIdx.x * 6 + 4] += share_mem[offset + 4];
      share_mem[threadIdx.x * 6 + 5] += share_mem[offset + 5];
    }
  }
  __syncthreads();
}

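// Final per-element pass with all six block-wide sums staged in share_mem:
//   d_dy = gamma*grad_dx*rstd + gamma*sum1/n + gamma*x_hat*sum2/n
//          + x_hat*grad_dg + grad_db
//   d_x += -(x - mean) * (var + epsilon)^(-1.5) * sum5/n + sum6/n
// global_sum1/global_sum2 cache sum1/n and sum2/n per element for the
// later GammaAndBetaPropKernel.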
template <typename T>
inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
                                 const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
                                 const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
                                 const T *share_mem, T *global_sum1, T *global_sum2) {
  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
    int pos = (row * col_dim + col);
    int gamma_offset = pos % param_dim;
    T v1 = x[pos] - mean[row];
    T v2 = my_pow(var[row] + epsilon, -0.5);
    T v3 = v1 * v2;
    T part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
    T part2 = gamma[gamma_offset] * share_mem[0] * (1.0 / col_dim);
    T part3 = gamma[gamma_offset] * v3 * share_mem[1] * (1.0 / col_dim);
    T part4 = v3 * grad_dg[gamma_offset];
    d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];

    T part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * (-1.0 / col_dim)));
    d_x[pos] += part5 + share_mem[5] * (1.0 / col_dim);
    global_sum1[pos] = share_mem[0] * (1.0 / col_dim);
    global_sum2[pos] = share_mem[1] * (1.0 / col_dim);
  }
}

template <>
inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
                                 const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
                                 const half *grad_dx, const half *grad_dg, const half *grad_db, half *d_dy, half *d_x,
                                 const half *share_mem, half *global_sum1, half *global_sum2) {
  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
    int pos = (row * col_dim + col);
    int gamma_offset = pos % param_dim;
    half v1 = x[pos] - mean[row];
    half v2 = my_pow(var[row] + epsilon, -0.5);
    half v3 = v1 * v2;
    half part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
    half part2 = gamma[gamma_offset] * share_mem[0] * __float2half(1.0 / col_dim);
    half part3 = gamma[gamma_offset] * v3 * share_mem[1] * __float2half(1.0 / col_dim);
    half part4 = v3 * grad_dg[gamma_offset];
    d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];

    half part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * __float2half(-1.0 / col_dim)));
    d_x[pos] += part5 + share_mem[5] * __float2half(1.0 / col_dim);
    global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim);
    global_sum2[pos] = share_mem[1] * __float2half(1.0 / col_dim);
  }
}

template <typename T>
__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
                                const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *global_sum1,
                                T *global_sum2) {
  for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
    T sum1 = 0;
    T sum2 = 0;
    T sum3 = 0;
    T sum4 = 0;
    T sum5 = 0;
    T sum6 = 0;
    DynamicSharedMem<T> share_mem;
    InputThreadReduceInnerMean(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, &sum4, dy, x, mean, var, gamma,
                               grad_dx);
    InputWarpReduceInnerMean(&sum1, &sum2, &sum3, &sum4);
    InputBlockReduceInnerMean(col_dim, &sum1, &sum2, &sum3, &sum4, share_mem.addr());
    InputThreadReduceOuterMean(row, col_dim, param_dim, epsilon, &sum5, &sum6, share_mem.addr(), dy, x, mean,
                               var, gamma, grad_dx, grad_dg, d_x);
    InputWarpReduceOuterMean(&sum5, &sum6);
    InputBlockReduceOuterMean(col_dim, &sum5, &sum6, share_mem.addr());
    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, grad_dx, grad_dg, grad_db, d_dy, d_x,
              share_mem.addr(), global_sum1, global_sum2);
  }
}

template <typename T>
void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, T *global_sum1, T *global_sum2,
                       const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
                       const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *d_gamma,
                       cudaStream_t stream) {
  const int thread_per_block = 256;
  int share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
                                                                         mean, var, gamma, grad_dx, grad_dg, grad_db,
                                                                         d_dy, d_x, global_sum1, global_sum2);

  share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_GAMMA * sizeof(T);
  int param_reduce_dim = row_dim * col_dim / param_dim;
  GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim,
                                                                                  col_dim, epsilon, dy, x, mean, var,
                                                                                  grad_dx, d_gamma, global_sum1,
                                                                                  global_sum2);
}

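// Illustrative only (not part of this file's contract): a minimal host-side
// sketch of driving the launcher for a float tensor of shape [rows, cols]
// with per-element gamma/beta (param_dim == cols). All names below are
// assumptions for the example; buffers are device allocations.
//
//   float *x, *dy, *mean, *var, *gamma, *grad_dx, *grad_dg, *grad_db;
//   float *d_dy, *d_x, *d_gamma, *sum1, *sum2;  // cudaMalloc'ed device buffers
//   cudaMemset(sum1, 0, rows * cols * sizeof(float));  // workspace zeroed, as
//   cudaMemset(sum2, 0, rows * cols * sizeof(float));  // in the kernel's Launch
//   LayerNormGradGrad(rows, cols, cols, sum1, sum2, 1e-7f, dy, x, mean, var,
//                     gamma, grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma,
//                     stream);
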
template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, float *global_sum1,
                                float *global_sum2, const float &epsilon, const float *dy, const float *x,
                                const float *mean, const float *var, const float *gamma, const float *grad_dx,
                                const float *grad_dg, const float *grad_db, float *d_dy, float *d_x, float *d_gamma,
                                cudaStream_t stream);
template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, half *global_sum1,
                                half *global_sum2, const half &epsilon, const half *dy, const half *x, const half *mean,
                                const half *var, const half *gamma, const half *grad_dx, const half *grad_dg,
                                const half *grad_db, half *d_dy, half *d_x, half *d_gamma, cudaStream_t stream);

@@ -0,0 +1,28 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_

#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, T *global_sum1, T *global_sum2,
                       const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
                       const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *d_gamma,
                       cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_

@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
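// Register LayerNormGradGrad for float32 and float16. Inputs, in order:
// x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db;
// outputs: d_x, d_dy, d_gamma.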
MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      LayerNormGradGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      LayerNormGradGradGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore

@@ -0,0 +1,128 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class LayerNormGradGradGpuKernel : public GpuKernel {
 public:
  LayerNormGradGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1), input_size_(1) {}
  ~LayerNormGradGradGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x = GetDeviceAddress<T>(inputs, 0);
    auto dy = GetDeviceAddress<T>(inputs, 1);
    auto var = GetDeviceAddress<T>(inputs, 2);
    auto mean = GetDeviceAddress<T>(inputs, 3);
    auto gamma = GetDeviceAddress<T>(inputs, 4);
    auto grad_dx = GetDeviceAddress<T>(inputs, 5);
    auto grad_dg = GetDeviceAddress<T>(inputs, 6);
    auto grad_db = GetDeviceAddress<T>(inputs, 7);
    auto d_x = GetDeviceAddress<T>(outputs, 0);
    auto d_dy = GetDeviceAddress<T>(outputs, 1);
    auto d_gamma = GetDeviceAddress<T>(outputs, 2);
    auto global_sum1 = GetDeviceAddress<T>(workspace, 0);
    auto global_sum2 = GetDeviceAddress<T>(workspace, 1);

    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemsetAsync(global_sum1, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemsetAsync global_sum1 failed");
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemsetAsync(global_sum2, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemsetAsync global_sum2 failed");

    const T epsilon = 10e-12;
    LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon, dy, x, mean, var, gamma,
                      grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

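  // input_row_  : product of dims before begin_norm_axis (rows normalized independently)
  // input_col_  : product of dims from begin_norm_axis on (normalized block size)
  // param_dim_  : product of dims from begin_params_axis on (gamma/beta size)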
  bool Init(const CNodePtr &kernel_node) override {
    int begin_norm_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_norm_axis"));
    int begin_params_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_params_axis"));

    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (begin_norm_axis < 0) {
      begin_norm_axis += input_shape.size();
    }
    if (begin_params_axis < 0) {
      begin_params_axis += input_shape.size();
    }

    for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
      input_row_ *= input_shape[i];
    }
    for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
      input_col_ *= input_shape[i];
    }
    for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
      param_dim_ *= input_shape[i];
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_ = input_row_ * input_col_ * sizeof(T);
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_row_ * sizeof(T));
    input_size_list_.push_back(input_row_ * sizeof(T));
    input_size_list_.push_back(param_dim_ * sizeof(T));
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(param_dim_ * sizeof(T));
    input_size_list_.push_back(param_dim_ * sizeof(T));

    output_size_list_.push_back(input_size_);
    output_size_list_.push_back(input_size_);
    output_size_list_.push_back(param_dim_ * sizeof(T));

    workspace_size_list_.push_back(input_size_);
    workspace_size_list_.push_back(input_size_);
    return;
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  int input_row_;
  int input_col_;
  int param_dim_;
  int input_size_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_

@@ -673,6 +673,19 @@ def get_bprop_layer_norm(self):
    return bprop


@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
    """Grad definition for `LayerNormGrad` operation."""
    layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)
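
    # `dout` carries the cotangents of LayerNormGrad's three outputs
    # (dx, dgamma, dbeta). LayerNormGrad treats variance and mean as constants,
    # so their gradients are returned as zeros_like placeholders.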
    def bprop(x, dy, variance, mean, gamma, out, dout):
        d_x, d_dy, d_gamma = layer_norm_grad_grad(
            x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
        return d_x, d_dy, d_gamma, zeros_like(variance), zeros_like(mean)

    return bprop


@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
    """Grad definition for `L2Normalize` operation."""

@@ -1087,6 +1087,33 @@ class LayerNormGrad(Primitive):
        raise NotImplementedError


class LayerNormGradGrad(PrimitiveWithInfer):
    """
    Gets the gradient of the LayerNormGrad operation.

    Args:
        begin_norm_axis (int): The first axis of the input over which layernorm is applied. Default: 1.
        begin_params_axis (int): The first axis to which the gamma/beta parameters apply. Default: 1.

    Returns:
        tuple[Tensor], a tuple of 3 values (the gradients with respect to LayerNormGrad's x, dy and gamma inputs).
    """

    @prim_attr_register
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        """init"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        raise NotImplementedError

    def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        return x, dy, gamma

    def infer_dtype(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        return x, dy, gamma


class LogSoftmaxGrad(PrimitiveWithInfer):
    """Computes gradient for the Log Softmax activation."""

| @@ -0,0 +1,142 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class LayerNormGradGradNet(nn.Cell): | |||
| def __init__(self, begin_norm_axis, begin_params_axis): | |||
| super(LayerNormGradGradNet, self).__init__() | |||
| self.norm = G.LayerNormGradGrad(begin_norm_axis, begin_params_axis) | |||
| def construct(self, x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db): | |||
| return self.norm(x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db) | |||
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))

    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var


def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, grad_db_np, begin_norm_axis,
                               begin_params_axis):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
    norm_axis = tuple([i for i in range(begin_norm_axis, len(x.shape))])
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, tuple(norm_axis), keepdims=True)
    var = np.mean(np.power((x - mean), 2), tuple(norm_axis), keepdims=True)
    inv_std = np.power(var + epsilon, -0.5)
    x_hat = (x - mean) * inv_std

    sum1 = np.mean((-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
    sum2 = np.mean(x_hat * (-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
    sum3 = np.mean(dy * gamma * x_hat, tuple(norm_axis), keepdims=True)
    part = dy * gamma * sum2 + sum3 * (-1.0) * grad_dx_np * inv_std + dy * grad_dg_np
    sum4 = np.mean((x - mean) * part, tuple(norm_axis), keepdims=True)
    sum5 = np.mean(-inv_std * part, tuple(norm_axis), keepdims=True)

    part1 = inv_std * part
    part2 = (x - mean) * (-1.0) * np.power(var + epsilon, -1.5) * sum4
    d_x = part1 + part2 + sum5

    part3 = gamma * grad_dx_np * inv_std
    part4 = gamma * sum1
    part5 = gamma * x_hat * sum2
    part6 = x_hat * grad_dg_np
    d_dy = part3 + part4 + part5 + part6 + grad_db_np

    part7 = np.sum(dy * x_hat * sum2, tuple(param_axis), keepdims=True)
    part8 = np.sum(dy * sum1, tuple(param_axis), keepdims=True)
    part9 = np.sum(dy * grad_dx_np * inv_std, tuple(param_axis), keepdims=True)
    d_gamma = part7 + part8 + part9

    return d_x, d_dy, d_gamma

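# Note: sum1/sum2/sum3 above play the role of the per-row inner means the CUDA
# kernel stages in shared memory, and sum4/sum5 the outer means; d_x, d_dy and
# d_gamma then follow the same part-by-part decomposition as InputProp and
# GammaAndBetaPropKernel.

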
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgradgrad0():
    np.random.seed(1)
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.rand(2, 2, 4).astype(np.float32)
    dy_np = np.random.rand(2, 2, 4).astype(np.float32)
    gamma_np = np.random.rand(2, 4).astype(np.float32)
    grad_dx_np = np.random.rand(2, 2, 4).astype(np.float32)
    grad_dg_np = np.random.rand(2, 4).astype(np.float32)
    grad_db_np = np.random.rand(2, 4).astype(np.float32)

    epsilon = 1e-12
    _, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                      begin_params_axis)
    d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
                                                             grad_db_np, begin_norm_axis, begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    grad_dx_ms = Tensor(grad_dx_np)
    grad_dg_ms = Tensor(grad_dg_np)
    grad_db_ms = Tensor(grad_db_np)

    net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
    d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)

    assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=1e-6, atol=1e-3)