Browse Source

Add new operator layer_norm_grad_grad

tags/v1.1.0
hedongdong 5 years ago
parent
commit
8bfeadf26e
7 changed files with 783 additions and 0 deletions
  1. +395
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu
  2. +28
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh
  3. +50
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc
  4. +128
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h
  5. +13
    -0
      mindspore/ops/_grad/grad_nn_ops.py
  6. +27
    -0
      mindspore/ops/operations/_grad_ops.py
  7. +142
    -0
      tests/st/ops/gpu/test_layer_norm_grad_grad_op.py

+ 395
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cu View File

@@ -0,0 +1,395 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh"

constexpr int NUM_PER_THREAD_REDUCE = 4;
constexpr int WARP_SIZE = 32;
constexpr int NUM_SHARED_SUM_INPUT = 6;
constexpr int NUM_SHARED_SUM_GAMMA = 3;

template <typename T>
inline __device__ T my_pow(T a, double b) {
return pow(a, static_cast<float>(b));
}

template <>
inline __device__ half my_pow(half a, double b) {
return __float2half(pow(__half2float(a), static_cast<float>(b)));
}


template <typename T>
inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
const int &mean_dim, const T &epsilon, const T *dy, const T *x,
const T *mean, const T *var, const T *grad_dx, T *part1, T *part2,
T *part3, const T *global_sum1, const T *global_sum2) {
int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
int row = NUM_PER_THREAD_REDUCE * i + j;
if (row >= row_dim) {
return;
}

int pos = row * col_dim + col;
int mean_offset = pos / mean_dim;

T v1 = my_pow(var[mean_offset] + epsilon, -0.5);

part1[0] += dy[pos] * v1 * (x[pos] - mean[mean_offset]) * global_sum2[pos];
part2[0] += dy[pos] * global_sum1[pos];
part3[0] += dy[pos] * grad_dx[pos] * v1;
}
}
}

template <typename T>
inline __device__ void GammaAndBetaWarpReduce(T *part1, T *part2, T *part3) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
part1[0] += __shfl_down_sync(0xffffffff, part1[0], delta);
part2[0] += __shfl_down_sync(0xffffffff, part2[0], delta);
part3[0] += __shfl_down_sync(0xffffffff, part3[0], delta);
}
}

template <typename T>
inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *part1, T *part2, T *part3,
T *d_gamma) {
// load data to share memory
// thread(0, 32, 64, 96, ...) keep the data
DynamicSharedMem<T> share_mem;
if (threadIdx.x % WARP_SIZE == 0) {
int offset = threadIdx.x / WARP_SIZE * 3;
share_mem.addr()[offset] = part1[0];
share_mem.addr()[offset + 1] = part2[0];
share_mem.addr()[offset + 2] = part3[0];
}
__syncthreads();

for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
int offset = (threadIdx.x + stride) * 3;
share_mem.addr()[threadIdx.x * 3] += share_mem.addr()[offset];
share_mem.addr()[threadIdx.x * 3 + 1] += share_mem.addr()[offset + 1];
share_mem.addr()[threadIdx.x * 3 + 2] += share_mem.addr()[offset + 2];
}
}
__syncthreads();

if (threadIdx.x == 0) {
d_gamma[col] = share_mem.addr()[0] + share_mem.addr()[1] + share_mem.addr()[2];
}
}

template <typename T>
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx,
T *d_gamma, T *global_sum1, T *global_sum2) {
for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
T part1 = 0;
T part2 = 0;
T part3 = 0;
GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean, var, grad_dx, &part1, &part2,
&part3, global_sum1, global_sum2);
GammaAndBetaWarpReduce(&part1, &part2, &part3);
GammaAndBetaBlockReduce(col, row_dim, &part1, &part2, &part3, d_gamma);
}
}


template <typename T>
inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col_dim, const int &param_dim,
const T &epsilon, T *sum1, T *sum2, T *sum3, T *sum4,
const T *dy, const T *x, const T *mean, const T *var,
const T *gamma, const T *grad_dx) {
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
int col = NUM_PER_THREAD_REDUCE * i + j;
if (col >= col_dim) {
return;
}
int pos = row * col_dim + col;
int gamma_offset = pos % param_dim;

T v1 = x[pos] - mean[row];
T v2 = my_pow(var[row] + epsilon, -0.5);
T v3 = v1 * v2;
T v4 = dy[pos] * gamma[gamma_offset];

sum1[0] -= v2 * grad_dx[pos];
sum2[0] -= v3 * v2 * grad_dx[pos];
sum3[0] += v4;
sum4[0] += v4 * v3;
}
}
}

template <typename T>
inline __device__ void InputWarpReduceInnerMean(T *sum1, T *sum2, T *sum3, T *sum4) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta);
sum4[0] += __shfl_down_sync(0xffffffff, sum4[0], delta);
}
}

template <typename T>
inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T *sum2, T *sum3, T *sum4, T *share_mem) {
// load data to share memory
// thread(0, 32, 64, 96, ...) keep the data
if (threadIdx.x % WARP_SIZE == 0) {
int offset = threadIdx.x / WARP_SIZE * 6;
share_mem[offset] = sum1[0];
share_mem[offset + 1] = sum2[0];
share_mem[offset + 2] = sum3[0];
share_mem[offset + 3] = sum4[0];
}
__syncthreads();

for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
int offset = (threadIdx.x + stride) * 6;

share_mem[threadIdx.x * 3] += share_mem[offset];
share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1];
share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2];
share_mem[threadIdx.x * 3 + 3] += share_mem[offset + 3];
}
}
__syncthreads();
}


template <typename T>
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
const T &epsilon, T *sum5, T *sum6, T *share_mem, const T *dy,
const T *x, const T *mean, const T *var, const T *gamma,
const T *grad_dx, const T *grad_dg, T *d_x) {
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
int col = NUM_PER_THREAD_REDUCE * i + j;
if (col >= col_dim) {
return;
}
int pos = row * col_dim + col;
int gamma_offset = pos % param_dim;

T v1 = x[pos] - mean[row];
T v2 = my_pow(var[row] + epsilon, -0.5);
T v3 = dy[pos] * gamma[gamma_offset];
T v4 = v3 * share_mem[1] * (1.0 / col_dim);
T v5 = grad_dx[pos] * v2 * share_mem[3] * (-1.0 / col_dim);
T v6 = dy[pos] * grad_dg[gamma_offset];
T v7 = v4 + v5 + v6;

T part1 = v1 * v7;
T part2 = v2 * v7;
d_x[pos] = part2;

sum5[0] += part1;
sum6[0] -= part2;
}
}
}


template <>
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
const half &epsilon, half *sum5, half *sum6, half *share_mem,
const half *dy, const half *x, const half *mean, const half *var,
const half *gamma, const half *grad_dx, const half *grad_dg,
half *d_x) {
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
int col = NUM_PER_THREAD_REDUCE * i + j;
if (col >= col_dim) {
return;
}
int pos = row * col_dim + col;
int gamma_offset = pos % param_dim;

half v1 = x[pos] - mean[row];
half v2 = my_pow(var[row] + epsilon, -0.5);
half v3 = dy[pos] * gamma[gamma_offset];
half v4 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
half v5 = grad_dx[pos] * v2 * share_mem[3] * __float2half(-1.0 / col_dim);
half v6 = dy[pos] * grad_dg[gamma_offset];
half v7 = v4 + v5 + v6;

half part1 = v1 * v7;
half part2 = v2 * v7;
d_x[pos] = part2;

sum5[0] += part1;
sum6[0] -= part2;
}
}
}


template <typename T>
inline __device__ void InputWarpReduceOuterMean(T *sum5, T *sum6) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
sum5[0] += __shfl_down_sync(0xffffffff, sum5[0], delta);
sum6[0] += __shfl_down_sync(0xffffffff, sum6[0], delta);
}
}

template <typename T>
inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T *sum6, T *share_mem) {
// load data to share memory
// thread(0, 32, 64, 96, ...) keep the data
if (threadIdx.x % WARP_SIZE == 0) {
int offset = threadIdx.x / WARP_SIZE * 6;

share_mem[offset + 4] = sum5[0];
share_mem[offset + 5] = sum6[0];
}
__syncthreads();

for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
int offset = (threadIdx.x + stride) * 6;

share_mem[threadIdx.x * 6 + 4] += share_mem[offset + 4];
share_mem[threadIdx.x * 6 + 5] += share_mem[offset + 5];
}
}
__syncthreads();
}


template <typename T>
inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
const T *share_mem, T *global_sum1, T *global_sum2) {
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
int pos = (row * col_dim + col);
int gamma_offset = pos % param_dim;

T v1 = x[pos] - mean[row];
T v2 = my_pow(var[row] + epsilon, -0.5);
T v3 = v1 * v2;

T part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
T part2 = gamma[gamma_offset] * share_mem[0] * (1.0 / col_dim);
T part3 = gamma[gamma_offset] * v3 * share_mem[1] * (1.0 / col_dim);
T part4 = v3 * grad_dg[gamma_offset];
d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];

T part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * (-1.0 / col_dim)));
d_x[pos] += part5 + share_mem[5] * (1.0 / col_dim);

global_sum1[pos] = share_mem[0] * (1.0 / col_dim);
global_sum2[pos] = share_mem[1] * (1.0 / col_dim);
}
}


template <>
inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
const half *grad_dx, const half *grad_dg, const half *grad_db, half *d_dy, half *d_x,
const half *share_mem, half *global_sum1, half *global_sum2) {
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
int pos = (row * col_dim + col);
int gamma_offset = pos % param_dim;

half v1 = x[pos] - mean[row];
half v2 = my_pow(var[row] + epsilon, -0.5);
half v3 = v1 * v2;

half part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
half part2 = gamma[gamma_offset] * share_mem[0] * __float2half(1.0 / col_dim);
half part3 = gamma[gamma_offset] * v3 * share_mem[1] * __float2half(1.0 / col_dim);
half part4 = v3 * grad_dg[gamma_offset];
d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];

half part5 = v1 * (my_pow(var[row] + epsilon, -1.5) * (share_mem[4] * __float2half(-1.0 / col_dim)));
d_x[pos] += part5 + share_mem[5] * __float2half(1.0 / col_dim);

global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim);
global_sum2[pos] = share_mem[1] * __float2half(1.0 / col_dim);
}
}


template <typename T>
__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *global_sum1,
T *global_sum2) {
for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
T sum1 = 0;
T sum2 = 0;
T sum3 = 0;
T sum4 = 0;
T sum5 = 0;
T sum6 = 0;
DynamicSharedMem<T> share_mem;

InputThreadReduceInnerMean(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, &sum4, dy, x, mean, var, gamma,
grad_dx);
InputWarpReduceInnerMean(&sum1, &sum2, &sum3, &sum4);
InputBlockReduceInnerMean(col_dim, &sum1, &sum2, &sum3, &sum4, share_mem.addr());

InputThreadReduceOuterMean(row, col_dim, param_dim, epsilon, &sum5, &sum6, share_mem.addr(), dy, x, mean,
var, gamma, grad_dx, grad_dg, d_x);
InputWarpReduceOuterMean(&sum5, &sum6);
InputBlockReduceOuterMean(col_dim, &sum5, &sum6, share_mem.addr());
InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, grad_dx, grad_dg, grad_db, d_dy, d_x,
share_mem.addr(), global_sum1, global_sum2);
}
}

template <typename T>
void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, T *global_sum1, T *global_sum2,
const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
const T* grad_dx, const T* grad_dg, const T* grad_db, T *d_dy, T *d_x, T *d_gamma,
cudaStream_t stream) {
const int thread_per_block = 256;

int share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
mean, var, gamma, grad_dx, grad_dg, grad_db,
d_dy, d_x, global_sum1, global_sum2);
share_mem_size = thread_per_block / WARP_SIZE * NUM_SHARED_SUM_GAMMA * sizeof(T);
int param_reduce_dim = row_dim * col_dim / param_dim;
GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim,
col_dim, epsilon, dy, x, mean, var,
grad_dx, d_gamma, global_sum1,
global_sum2);
}

template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, float *global_sum1,
float *global_sum2, const float &epsilon, const float *dy, const float *x,
const float *mean, const float *var, const float *gamma, const float *grad_dx,
const float *grad_dg, const float *grad_db, float *d_dy, float *d_x, float *d_gamma,
cudaStream_t stream);
template void LayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, half *global_sum1,
half *global_sum2, const half &epsilon, const half *dy, const half *x, const half *mean,
const half *var, const half *gamma, const half *grad_dx, const half *grad_dg,
const half *grad_db, half *d_dy, half *d_x, half *d_gamma, cudaStream_t stream);

+ 28
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh View File

@@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_

#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void LayerNormGradGrad(const int& row_dim, const int& col_dim, const int& param_dim, T* global_sum1, T* global_sum2,
const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, const T* gamma,
const T* grad_dx, const T* grad_dg, const T* grad_db, T* d_dy, T* d_x, T* d_gamma,
cudaStream_t stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_GRAD_H_

+ 50
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.cc View File

@@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
LayerNormGradGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LayerNormGradGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
LayerNormGradGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

+ 128
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_grad_gpu_kernel.h View File

@@ -0,0 +1,128 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_grad_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class LayerNormGradGradGpuKernel : public GpuKernel {
public:
LayerNormGradGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1), input_size_(1) {}
~LayerNormGradGradGpuKernel() override = default;

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto x = GetDeviceAddress<T>(inputs, 0);
auto dy = GetDeviceAddress<T>(inputs, 1);
auto var = GetDeviceAddress<T>(inputs, 2);
auto mean = GetDeviceAddress<T>(inputs, 3);
auto gamma = GetDeviceAddress<T>(inputs, 4);
auto grad_dx = GetDeviceAddress<T>(inputs, 5);
auto grad_dg = GetDeviceAddress<T>(inputs, 6);
auto grad_db = GetDeviceAddress<T>(inputs, 7);
auto d_x = GetDeviceAddress<T>(outputs, 0);
auto d_dy = GetDeviceAddress<T>(outputs, 1);
auto d_gamma = GetDeviceAddress<T>(outputs, 2);

auto global_sum1 = GetDeviceAddress<T>(workspace, 0);
auto global_sum2 = GetDeviceAddress<T>(workspace, 1);

CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemsetAsync(global_sum1, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync global_sum1 failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemsetAsync(global_sum2, 0, input_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync global_sum2 failed");

const T epsilon = 10e-12;
LayerNormGradGrad(input_row_, input_col_, param_dim_, global_sum1, global_sum2, epsilon, dy, x, mean, var, gamma,
grad_dx, grad_dg, grad_db, d_dy, d_x, d_gamma, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
int begin_norm_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_norm_axis"));
int begin_params_axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "begin_params_axis"));

auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (begin_norm_axis < 0) {
begin_norm_axis += input_shape.size();
}

if (begin_params_axis < 0) {
begin_params_axis += input_shape.size();
}

for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
input_row_ *= input_shape[i];
}

for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
input_col_ *= input_shape[i];
}

for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
param_dim_ *= input_shape[i];
}

InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_ = input_row_ * input_col_ * sizeof(T);
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_row_ * sizeof(T));
input_size_list_.push_back(input_row_ * sizeof(T));
input_size_list_.push_back(param_dim_ * sizeof(T));
input_size_list_.push_back(input_size_);
input_size_list_.push_back(param_dim_ * sizeof(T));
input_size_list_.push_back(param_dim_ * sizeof(T));

output_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
output_size_list_.push_back(param_dim_ * sizeof(T));

workspace_size_list_.push_back(input_size_);
workspace_size_list_.push_back(input_size_);
return;
}

private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

int input_row_;
int input_col_;
int param_dim_;
int input_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAYER_NORM_GRAD_GRAD_GPU_KERNEL_H_

+ 13
- 0
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -673,6 +673,19 @@ def get_bprop_layer_norm(self):
return bprop


@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
"""Grad definition for `LayerNorm` operation."""
layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)

def bprop(x, dy, variance, mean, gamma, out, dout):
d_x, d_dy, d_gamma = layer_norm_grad_grad(
x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
return d_x, d_dy, d_gamma, zeros_like(variance), zeros_like(mean)

return bprop


@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
"""Grad definition for `L2Normalize` operation."""


+ 27
- 0
mindspore/ops/operations/_grad_ops.py View File

@@ -1087,6 +1087,33 @@ class LayerNormGrad(Primitive):
raise NotImplementedError


class LayerNormGradGrad(PrimitiveWithInfer):
"""
Gets the gradient of LayerNormGrad operation.

Args:
begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.

Returns:
tuple[int], tuple of 3 values (the gradients of layernormgrad input, dy, gamma).
"""

@prim_attr_register
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
"""init"""
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
raise NotImplementedError
def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
return x, dy, gamma

def infer_dtype(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
return x, dy, gamma


class LogSoftmaxGrad(PrimitiveWithInfer):
"""Computes gradient for the Log Softmax activation."""



+ 142
- 0
tests/st/ops/gpu/test_layer_norm_grad_grad_op.py View File

@@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class LayerNormGradGradNet(nn.Cell):
def __init__(self, begin_norm_axis, begin_params_axis):
super(LayerNormGradGradNet, self).__init__()
self.norm = G.LayerNormGradGrad(begin_norm_axis, begin_params_axis)

def construct(self, x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db):
return self.norm(x, dy, var, mean, gamma, grad_dx, grad_dg, grad_db)


def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
param_axis = [i for i in range(0, begin_params_axis)]
num = 1
for i in range(begin_norm_axis, len(x.shape)):
num *= x.shape[i]

mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
var = np.var(x, axis=tuple(norm_axis), keepdims=True)

gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
keepdims=True)
sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

dx1 = dy * gamma * np.power(var + epsilon, -0.5)
dx2 = sum1 * 2.0 / num * (x - mean)
dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
dx = dx1 + dx2 + dx3
return dx, dg, db, mean, var

def LayerNormGradGradReference(x, dy, gamma, epsilon, grad_dx_np, grad_dg_np, grad_db_np, begin_norm_axis,
begin_params_axis):
begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

norm_axis = tuple([i for i in range(begin_norm_axis, len(x.shape))])
param_axis = [i for i in range(0, begin_params_axis)]
num = 1
for i in range(begin_norm_axis, len(x.shape)):
num *= x.shape[i]

mean = np.mean(x, tuple(norm_axis), keepdims=True)
var = np.mean(np.power((x - mean), 2), tuple(norm_axis), keepdims=True)
inv_std = np.power(var + epsilon, -0.5)
x_hat = (x - mean) * inv_std

sum1 = np.mean((-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
sum2 = np.mean(x_hat * (-1.0) * inv_std * grad_dx_np, tuple(norm_axis), keepdims=True)
sum3 = np.mean(dy * gamma * x_hat, tuple(norm_axis), keepdims=True)
part = dy * gamma * sum2 + sum3 * (-1.0) * grad_dx_np * inv_std + dy * grad_dg_np
sum4 = np.mean((x - mean) * part, tuple(norm_axis), keepdims=True)
sum5 = np.mean(-inv_std * part, tuple(norm_axis), keepdims=True)

part1 = inv_std * part
part2 = (x - mean) * (-1.0) * np.power(var + epsilon, -1.5) * sum4
d_x = part1 + part2 + sum5

part3 = gamma * grad_dx_np * inv_std
part4 = gamma * sum1
part5 = gamma * x_hat * sum2
part6 = x_hat * grad_dg_np
d_dy = part3 + part4 + part5 + part6 + grad_db_np

part7 = np.sum(dy * x_hat * sum2, tuple(param_axis), keepdims=True)
part8 = np.sum(dy * sum1, tuple(param_axis), keepdims=True)
part9 = np.sum(dy * grad_dx_np * inv_std, tuple(param_axis), keepdims=True)
d_gamma = part7 + part8 + part9

return d_x, d_dy, d_gamma


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgradgrad0():
np.random.seed(1)
begin_norm_axis = 1
begin_params_axis = 1

x_np = np.random.rand(2, 2, 4).astype(np.float32)
dy_np = np.random.rand(2, 2, 4).astype(np.float32)
gamma_np = np.random.rand(2, 4).astype(np.float32)

grad_dx_np = np.random.rand(2, 2, 4).astype(np.float32)
grad_dg_np = np.random.rand(2, 4).astype(np.float32)
grad_db_np = np.random.rand(2, 4).astype(np.float32)

epsilon = 1e-12
_, _, _, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
begin_params_axis)

d_x_np, d_dy_np, d_gamma_np = LayerNormGradGradReference(x_np, dy_np, gamma_np, epsilon, grad_dx_np, grad_dg_np,
grad_db_np, begin_norm_axis, begin_params_axis)

dy_ms = Tensor(dy_np)
x_ms = Tensor(x_np)
var_ms = Tensor(var_np)
mean_ms = Tensor(mean_np)
gamma_ms = Tensor(gamma_np)
grad_dx_ms = Tensor(grad_dx_np)
grad_dg_ms = Tensor(grad_dg_np)
grad_db_ms = Tensor(grad_db_np)

net = LayerNormGradGradNet(begin_norm_axis, begin_params_axis)
d_x_ms, d_dy_ms, d_gamma_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms, grad_dx_ms, grad_dg_ms, grad_db_ms)

assert np.allclose(d_x_ms.asnumpy(), d_x_np, rtol=1e-6, atol=1e-3)
assert np.allclose(d_dy_ms.asnumpy(), d_dy_np, rtol=1e-6, atol=1e-6)
assert np.allclose(d_gamma_ms.asnumpy(), d_gamma_np, rtol=1e-6, atol=1e-3)

Loading…
Cancel
Save