|
|
|
@@ -35,8 +35,8 @@ inline __device__ half my_pow(half a, double b) { |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim, |
|
|
|
const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, |
|
|
|
T *dg, T *db) { |
|
|
|
const int &mean_dim, const T &epsilon, const T *dy, const T *x, |
|
|
|
const T *mean, const T *var, T *dg, T *db) { |
|
|
|
int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; |
|
|
|
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { |
|
|
|
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { |
|
|
|
@@ -46,7 +46,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d |
|
|
|
} |
|
|
|
|
|
|
|
int pos = row * col_dim + col; |
|
|
|
dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); |
|
|
|
int mean_offset = pos / mean_dim; |
|
|
|
dg[0] += dy[pos] * my_pow(var[mean_offset] + epsilon, -0.5) * (x[pos] - mean[mean_offset]); |
|
|
|
db[0] += dy[pos]; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -89,8 +90,9 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x, |
|
|
|
const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) { |
|
|
|
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon, |
|
|
|
const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr, |
|
|
|
T *db_addr) { |
|
|
|
// row: [0:param_axis] |
|
|
|
// col: [param_axis:] |
|
|
|
// dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) |
|
|
|
@@ -98,7 +100,7 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con |
|
|
|
for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { |
|
|
|
T dg = 0; |
|
|
|
T db = 0; |
|
|
|
GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); |
|
|
|
GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); |
|
|
|
GammaAndBetaWarpReduce(&dg, &db); |
|
|
|
GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); |
|
|
|
} |
|
|
|
@@ -239,8 +241,12 @@ void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, |
|
|
|
mean, var, gamma, dx); |
|
|
|
|
|
|
|
share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T); |
|
|
|
GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, |
|
|
|
var, dg, db); |
|
|
|
// GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, |
|
|
|
// mean, |
|
|
|
// var, dg, db); |
|
|
|
int param_reduce_dim = row_dim * col_dim / param_dim; |
|
|
|
GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim, |
|
|
|
epsilon, dy, x, mean, var, dg, db); |
|
|
|
} |
|
|
|
|
|
|
|
template void LayerNormGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, |
|
|
|
|