|
|
|
@@ -34,9 +34,9 @@ inline __device__ half my_pow(half a, double b) {
 }
 
 template <typename T>
-inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
-                                                const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
-                                                T* dg, T* db) {
+inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
+                                                const T &epsilon, const T *dy, const T *x, const T *mean, const T *var,
+                                                T *dg, T *db) {
   int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
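
The `loop_num` expression in the context above is integer ceiling division: each thread
covers NUM_PER_THREAD_REDUCE consecutive rows per chunk, and the block strides across the
chunks. A minimal sketch of the tiling, with illustrative names not taken from the file:

// (n + k - 1) / k is ceil(n / k) for positive ints, e.g. (100 + 3) / 4 == 25.
__device__ void thread_reduce_sketch(int row_dim, float *acc) {
  const int kPerThread = 4;  // stands in for NUM_PER_THREAD_REDUCE
  int loop_num = (row_dim + kPerThread - 1) / kPerThread;
  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
    for (int j = 0; j < kPerThread; j++) {
      int row = i * kPerThread + j;
      if (row >= row_dim) break;  // last chunk may be partial
      *acc += 1.0f;               // placeholder for the real accumulation
    }
  }
}
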
@@ -53,7 +53,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d
 }
 
 template <typename T>
-inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
+inline __device__ void GammaAndBetaWarpReduce(T *dg, T *db) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta);
     db[0] += __shfl_down_sync(0xffffffff, db[0], delta);
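
GammaAndBetaWarpReduce is the standard shuffle-down folding: with WARP_SIZE 32, deltas
16, 8, 4, 2, 1 successively halve the live lanes until lane 0 holds the warp total. A
self-contained sketch of the same idiom:

// After the loop, lane 0 holds the sum of all 32 lane values.
__device__ float warp_sum(float v) {
  for (int delta = 16; delta > 0; delta >>= 1) {
    v += __shfl_down_sync(0xffffffff, v, delta);
  }
  return v;  // meaningful in lane 0 only
}
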
@@ -61,12 +61,8 @@ inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
 }
 
 template <typename T>
-inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr,
-                                               T* db_addr) {
-  if (threadIdx.x >= row_dim) {
-    return;
-  }
-
+inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
+                                               T *db_addr) {
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   DynamicSharedMem<T> share_mem;
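
Removing the `threadIdx.x >= row_dim` early return here (and the matching one in
InputBlockReduce below) looks deliberate rather than cosmetic: a block-level reduction
that stages partials in shared memory must be entered by every thread of the block,
because a thread that returns before a `__syncthreads()` the rest of the block reaches
is undefined behavior. A sketch of the hazard the old code risked, assuming the reduce
synchronizes on shared memory as the staging comments suggest:

// Anti-pattern: early exit ahead of a block-wide barrier.
__device__ void block_reduce_bad(int n, float *smem, float v) {
  if (threadIdx.x >= n) return;  // these threads never reach the barrier
  smem[threadIdx.x] = v;
  __syncthreads();  // undefined behavior: not all threads participate
}
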
@@ -93,8 +89,8 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di
 }
 
 template <typename T>
-__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x,
-                                       const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) {
+__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x,
+                                       const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) {
   // row: [0:param_axis]
   // col: [param_axis:]
   // dg[i][j] = dy[i][j] * my_pow(var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
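
For readers checking the math in those comments: with the normalized value
x_hat[i][j] = (x[i][j] - mean[i]) * my_pow(var[i] + epsilon, -0.5), these are the
standard layer-norm parameter gradients, accumulated down the rows for each column j:

dg[j] = sum_i dy[i][j] * (x[i][j] - mean[i]) * my_pow(var[i] + epsilon, -0.5)
db[j] = sum_i dy[i][j]
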
@@ -109,9 +105,9 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
 }
 
 template <typename T>
-inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
-                                         T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean,
-                                         const T* var, const T* gamma) {
+inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
+                                         T *sum1, T *sum2, T *sum3, const T *dy, const T *x, const T *mean,
+                                         const T *var, const T *gamma) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -133,9 +129,9 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
 }
 
 template <>
-inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
-                                         half* sum1, half* sum2, half* sum3, const half* dy, const half* x,
-                                         const half* mean, const half* var, const half* gamma) {
+inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
+                                         half *sum1, half *sum2, half *sum3, const half *dy, const half *x,
+                                         const half *mean, const half *var, const half *gamma) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
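
The explicit half specialization exists because device-side half arithmetic does not mix
implicitly with float or double literals the way the generic T path can; constants have
to go through explicit converters. An illustrative fragment of that idiom (not the
file's code):

#include <cuda_fp16.h>

// Scale a half by a float constant; half operator* needs sm_53 or newer.
__device__ half scale_half(half v, float s) {
  return v * __float2half(s);
}
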
@@ -157,7 +153,7 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
 }
 
 template <typename T>
-inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
+inline __device__ void InputWarpReduce(T *sum1, T *sum2, T *sum3) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
     sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
@@ -166,11 +162,7 @@ inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
 }
 
 template <typename T>
-inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) {
-  if (threadIdx.x >= col_dim) {
-    return;
-  }
-
+inline __device__ void InputBlockReduce(const int &col_dim, T *sum1, T *sum2, T *sum3, T *share_mem) {
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   if (threadIdx.x % WARP_SIZE == 0) {
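
The "thread(0, 32, 64, 96, ...) keep the data" comment describes the usual two-level
block reduction: after the warp shuffle, lane 0 of each warp holds that warp's partial,
and only those leaders write to shared memory before the final combine. A generic sketch
of the staging step, assuming one shared slot per warp:

// Threads 0, 32, 64, ... stage their warp's partial into smem[warp_id].
__device__ void stage_warp_partials(float partial, float *smem) {
  const int warp_id = threadIdx.x / 32;
  if (threadIdx.x % 32 == 0) {
    smem[warp_id] = partial;
  }
  __syncthreads();  // make all partials visible before combining
}
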
@@ -193,9 +185,9 @@ inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T*
 }
 
 template <typename T>
-inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
-                                 const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx,
-                                 const T* share_mem) {
+inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
+                                 const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx,
+                                 const T *share_mem) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
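
The `pos % param_dim` offset is how gamma, which spans only the parameter axes, is
broadcast across the leading axes: element `pos` of the flattened input reuses gamma
entry `pos % param_dim`. For example, normalizing a [4, 8] input over its last axis
gives col_dim == param_dim == 8, so positions 0, 8, 16, and 24 all read gamma[0].
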
@@ -208,9 +200,9 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
 }
 
 template <>
-inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
-                                 const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
-                                 half* dx, const half* share_mem) {
+inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
+                                 const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
+                                 half *dx, const half *share_mem) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
@@ -218,14 +210,14 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
     half v2 = x[pos] - mean[row];
     half v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
-              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\
-              * __float2half(1.0 / col_dim);
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
+                __float2half(1.0 / col_dim);
   }
 }
 
 template <typename T>
-__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy,
-                                const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
+__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *dy,
+                                const T *x, const T *mean, const T *var, const T *gamma, T *dx) {
   for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
     T sum1 = 0;
     T sum2 = 0;
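
The edit in the half InputProp body is purely formatting: the stray `\` line
continuation is dropped and the `*` moves to the end of the wrapped line; the value
computed is unchanged. Note that `my_pow(var[row] + epsilon, -0.5)` is the reciprocal
standard deviation. One plausible way such a helper is realized for half, round-tripping
through float (illustrative; the body of my_pow is outside this hunk):

#include <cuda_fp16.h>

// rsqrt in half precision via a float round-trip.
__device__ half half_rsqrt(half v) {
  return __float2half(rsqrtf(__half2float(v)));
}
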
@@ -239,21 +231,21 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
 }
 
 template <typename T>
-void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
-                   const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
-  int share_mem_size =
-    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
-                                                            gamma, dx);
-
-  share_mem_size =
-    ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
+                   const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
+  const int thread_per_block = 256;
+  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
+  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
+                                                                         mean, var, gamma, dx);
+
+  share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
+  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
+                                                                                var, dg, db);
 }
 
-template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
-                            const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
-                            float* dx, float* dg, float* db, cudaStream_t stream);
-template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon,
-                            const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
-                            half* dx, half* dg, half* db, cudaStream_t stream);
+template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
+                            const float *dy, const float *x, const float *mean, const float *var, const float *gamma,
+                            float *dx, float *dg, float *db, cudaStream_t stream);
+template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon,
+                            const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
+                            half *dx, half *dg, half *db, cudaStream_t stream);
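
The substantive change in this last hunk is the shared-memory sizing. The old
expressions derived the allocation from col_dim and row_dim, but the number of shared
slots the reductions actually consume is one per warp per reduced value, which depends
only on the launch configuration; once the early returns are gone and every warp leader
writes a (possibly zero) partial, sizing by the block is also the safe choice. With 256
threads and WARP_SIZE 32 there are 8 warps, so InputPropKernel (three running sums)
needs 8 * 3 slots and GammaAndBetaPropKernel (dg and db) needs 8 * 2. Checking the
arithmetic for float:

// 256 threads / 32 lanes = 8 warps per block.
static_assert(256 / 32 * 3 * sizeof(float) == 96, "InputPropKernel smem bytes");
static_assert(256 / 32 * 2 * sizeof(float) == 64, "GammaAndBetaPropKernel smem bytes");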