| @@ -18,10 +18,21 @@ | |||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||
| #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" | #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" | ||||
| #include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" | |||||
| constexpr int NUM_PER_THREAD_REDUCE = 4; | constexpr int NUM_PER_THREAD_REDUCE = 4; | ||||
| constexpr int WARP_SIZE = 32; | constexpr int WARP_SIZE = 32; | ||||
| // Device-side power helper: returns `a` raised to the exponent `b`. | |||||
| // The exponent is narrowed to float before it is handed to pow(). | |||||
| template <typename T> | |||||
| inline __device__ T my_pow(T a, double b) { | |||||
|   const float exponent = static_cast<float>(b); | |||||
|   return pow(a, exponent); | |||||
| } | |||||
| // half specialization of my_pow: the power is evaluated in single precision | |||||
| // and the result is rounded back to half. | |||||
| template <> | |||||
| inline __device__ half my_pow(half a, double b) { | |||||
|   const float base = __half2float(a); | |||||
|   const float exponent = static_cast<float>(b); | |||||
|   return __float2half(pow(base, exponent)); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, | inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, | ||||
| const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, | const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, | ||||
| @@ -35,7 +46,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d | |||||
| } | } | ||||
| int pos = row * col_dim + col; | int pos = row * col_dim + col; | ||||
| dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); | |||||
| dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); | |||||
| db[0] += dy[pos]; | db[0] += dy[pos]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -58,26 +69,26 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di | |||||
| // load data to share memory | // load data to share memory | ||||
| // thread(0, 32, 64, 96, ...) keep the data | // thread(0, 32, 64, 96, ...) keep the data | ||||
| extern __shared__ T share_mem[]; | |||||
| DynamicSharedMem<T> share_mem; | |||||
| if (threadIdx.x % WARP_SIZE == 0) { | if (threadIdx.x % WARP_SIZE == 0) { | ||||
| int offset = threadIdx.x / WARP_SIZE * 2; | int offset = threadIdx.x / WARP_SIZE * 2; | ||||
| share_mem[offset] = dg[0]; | |||||
| share_mem[offset + 1] = db[0]; | |||||
| share_mem.addr()[offset] = dg[0]; | |||||
| share_mem.addr()[offset + 1] = db[0]; | |||||
| } | } | ||||
| __syncthreads(); | __syncthreads(); | ||||
| for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { | for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { | ||||
| if (threadIdx.x < stride) { | if (threadIdx.x < stride) { | ||||
| int offset = (threadIdx.x + stride) * 2; | int offset = (threadIdx.x + stride) * 2; | ||||
| share_mem[threadIdx.x * 2] += share_mem[offset]; | |||||
| share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1]; | |||||
| share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset]; | |||||
| share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1]; | |||||
| } | } | ||||
| } | } | ||||
| __syncthreads(); | __syncthreads(); | ||||
| if (threadIdx.x == 0) { | if (threadIdx.x == 0) { | ||||
| dg_addr[col] = share_mem[0]; | |||||
| db_addr[col] = share_mem[1]; | |||||
| dg_addr[col] = share_mem.addr()[0]; | |||||
| db_addr[col] = share_mem.addr()[1]; | |||||
| } | } | ||||
| } | } | ||||
| @@ -114,13 +125,37 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con | |||||
| T v1 = dy[pos] * gamma[gamma_offset]; | T v1 = dy[pos] * gamma[gamma_offset]; | ||||
| T v2 = x[pos] - mean[row]; | T v2 = x[pos] - mean[row]; | ||||
| sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5); | |||||
| sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5); | |||||
| sum2[0] += v1; | sum2[0] += v1; | ||||
| sum3[0] += -2.0 * v2; | sum3[0] += -2.0 * v2; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // half specialization of the per-thread partial reduction over one row. | |||||
| // Each thread walks chunks of NUM_PER_THREAD_REDUCE consecutive columns and adds into | |||||
| //   sum1: -0.5 * (dy * gamma) * (x - mean) * (var + epsilon)^-1.5 | |||||
| //   sum2: dy * gamma | |||||
| //   sum3: -2 * (x - mean) | |||||
| // The partial sums are accumulated in float and folded into the half outputs once at the end, | |||||
| // so long rows do not lose precision (or overflow) through repeated half-precision additions. | |||||
| template <> | |||||
| inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, | |||||
|                                          half* sum1, half* sum2, half* sum3, const half* dy, const half* x, | |||||
|                                          const half* mean, const half* var, const half* gamma) { | |||||
|   // Row-invariant terms, evaluated once in single precision. | |||||
|   const float row_mean = __half2float(mean[row]); | |||||
|   const float inv_std_cubed = powf(__half2float(var[row]) + __half2float(epsilon), -1.5f); | |||||
|   float acc1 = 0.0f; | |||||
|   float acc2 = 0.0f; | |||||
|   float acc3 = 0.0f; | |||||
|   const int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; | |||||
|   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { | |||||
|     const int first = NUM_PER_THREAD_REDUCE * i; | |||||
|     // Clamp the last chunk so columns past the end of the row are never touched. | |||||
|     const int last = min(first + NUM_PER_THREAD_REDUCE, col_dim); | |||||
|     for (int col = first; col < last; ++col) { | |||||
|       const int pos = row * col_dim + col; | |||||
|       const int gamma_offset = pos % param_dim; | |||||
|       const float v1 = __half2float(dy[pos]) * __half2float(gamma[gamma_offset]); | |||||
|       const float v2 = __half2float(x[pos]) - row_mean; | |||||
|       acc1 += -0.5f * v1 * v2 * inv_std_cubed; | |||||
|       acc2 += v1; | |||||
|       acc3 += -2.0f * v2; | |||||
|     } | |||||
|   } | |||||
|   sum1[0] += __float2half(acc1); | |||||
|   sum2[0] += __float2half(acc2); | |||||
|   sum3[0] += __float2half(acc3); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { | inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { | ||||
| for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { | for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { | ||||
| @@ -166,12 +201,28 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int& | |||||
| int gamma_offset = pos % param_dim; | int gamma_offset = pos % param_dim; | ||||
| T v1 = dy[pos] * gamma[gamma_offset]; | T v1 = dy[pos] * gamma[gamma_offset]; | ||||
| T v2 = x[pos] - mean[row]; | T v2 = x[pos] - mean[row]; | ||||
| T v3 = pow(var[row] + epsilon, -0.5); | |||||
| T v3 = my_pow(var[row] + epsilon, -0.5); | |||||
| dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + | dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + | ||||
| (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); | (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); | ||||
| } | } | ||||
| } | } | ||||
| // half specialization that writes dx for one row from the three reduced sums in share_mem[0..2]. | |||||
| // Per element: dx = v1 * v3 + share_mem[0] * (2 / col_dim) * v2 | |||||
| //                   + (-v3 * share_mem[1] + (1 / col_dim) * share_mem[0] * share_mem[2]) * (1 / col_dim) | |||||
| // with v1 = dy * gamma, v2 = x - mean[row], v3 = (var[row] + epsilon)^-0.5. | |||||
| // Every term that depends only on the row is computed once before the column loop; the half | |||||
| // operations are applied in the same order as in the per-element formula, so results are unchanged. | |||||
| template <> | |||||
| inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, | |||||
|                                  const half* dy, const half* x, const half* mean, const half* var, const half* gamma, | |||||
|                                  half* dx, const half* share_mem) { | |||||
|   const half row_mean = mean[row]; | |||||
|   const half inv_std = my_pow(var[row] + epsilon, -0.5); | |||||
|   const half inv_cols = __float2half(1.0 / col_dim); | |||||
|   // Coefficient multiplying (x - mean) in the second term. | |||||
|   const half x_coeff = share_mem[0] * __float2half(2.0 / col_dim); | |||||
|   // Third term: identical for every column of this row. | |||||
|   const half row_term = | |||||
|     (__float2half(-1.0) * inv_std * share_mem[1] + inv_cols * share_mem[0] * share_mem[2]) * inv_cols; | |||||
|   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { | |||||
|     const int pos = row * col_dim + col; | |||||
|     const int gamma_offset = pos % param_dim; | |||||
|     const half v1 = dy[pos] * gamma[gamma_offset]; | |||||
|     const half v2 = x[pos] - row_mean; | |||||
|     dx[pos] = v1 * inv_std + x_coeff * v2 + row_term; | |||||
|   } | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, | __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, | ||||
| const T* x, const T* mean, const T* var, const T* gamma, T* dx) { | const T* x, const T* mean, const T* var, const T* gamma, T* dx) { | ||||
| @@ -179,27 +230,30 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int | |||||
| T sum1 = 0; | T sum1 = 0; | ||||
| T sum2 = 0; | T sum2 = 0; | ||||
| T sum3 = 0; | T sum3 = 0; | ||||
| extern __shared__ T share_mem[]; | |||||
| DynamicSharedMem<T> share_mem; | |||||
| InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); | InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); | ||||
| InputWarpReduce(&sum1, &sum2, &sum3); | InputWarpReduce(&sum1, &sum2, &sum3); | ||||
| InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem); | |||||
| InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem); | |||||
| InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr()); | |||||
| InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr()); | |||||
| } | } | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, | void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, | ||||
| const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { | const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { | ||||
| int share_mem = | |||||
| int share_mem_size = | |||||
| ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | ||||
| InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, | |||||
| dx); | |||||
| InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, | |||||
| gamma, dx); | |||||
| share_mem = | |||||
| share_mem_size = | |||||
| ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); | ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); | ||||
| GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); | |||||
| GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); | |||||
| } | } | ||||
| template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, | template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, | ||||
| const float* dy, const float* x, const float* mean, const float* var, const float* gamma, | const float* dy, const float* x, const float* mean, const float* var, const float* gamma, | ||||
| float* dx, float* dg, float* db, cudaStream_t stream); | float* dx, float* dg, float* db, cudaStream_t stream); | ||||
| template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, | |||||
| const half* dy, const half* x, const half* mean, const half* var, const half* gamma, | |||||
| half* dx, half* dg, half* db, cudaStream_t stream); | |||||
| @@ -35,7 +35,8 @@ inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T & | |||||
| template <typename T> | template <typename T> | ||||
| inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) { | inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) { | ||||
| if (n2 == 0) { | |||||
| T zero = 0; | |||||
| if (n2 == zero) { | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -112,6 +113,17 @@ inline __device__ void LayerNorm(const int &row, const int &col_dim, const int & | |||||
| } | } | ||||
| } | } | ||||
| // half specialization of the normalization step for one row: | |||||
| //   y = (x - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma + beta | |||||
| // share_mem[0] and share_mem[1] are read as the row's mean and variance; gamma and beta are | |||||
| // indexed by the flat element position modulo param_dim. | |||||
| template <> | |||||
| inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const half *x, | |||||
|                                  const half *share_mem, const half *gamma, const half *beta, const half epsilon, | |||||
|                                  half *y) { | |||||
|   // Both values are constant for the whole row, so evaluate the square root once. | |||||
|   const half row_mean = share_mem[0]; | |||||
|   const half row_std = hsqrt(share_mem[1] + epsilon); | |||||
|   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { | |||||
|     const int pos = row * col_dim + col; | |||||
|     const int i = pos % param_dim; | |||||
|     y[pos] = (x[pos] - row_mean) / row_std * gamma[i] + beta[i]; | |||||
|   } | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x, | __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x, | ||||
| const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) { | const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) { | ||||
| @@ -120,14 +132,14 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int | |||||
| T var = 0; | T var = 0; | ||||
| T num = 0; | T num = 0; | ||||
| const T *block_addr = x + row * col_dim; | const T *block_addr = x + row * col_dim; | ||||
| extern __shared__ T share_mem[]; | |||||
| DynamicSharedMem<T> share_mem; | |||||
| ThreadReduce(col_dim, block_addr, &mean, &var, &num); | ThreadReduce(col_dim, block_addr, &mean, &var, &num); | ||||
| WarpReduce(&mean, &var, &num); | WarpReduce(&mean, &var, &num); | ||||
| BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem); | |||||
| BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr()); | |||||
| __syncthreads(); | __syncthreads(); | ||||
| LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y); | |||||
| LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y); | |||||
| } | } | ||||
| } | } | ||||
| @@ -137,12 +149,15 @@ void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, con | |||||
| const dim3 block(row_dim); | const dim3 block(row_dim); | ||||
| const dim3 thread(256); | const dim3 thread(256); | ||||
| // keep the mean/var/num after warp reduce | // keep the mean/var/num after warp reduce | ||||
| int share_mem = | |||||
| int share_mem_size = | |||||
| ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | ||||
| LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, | |||||
| var); | |||||
| LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, | |||||
| mean, var); | |||||
| } | } | ||||
| // Explicit instantiation of the host launcher for float. | // Explicit instantiation of the host launcher for float. | ||||
| template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon, | template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon, | ||||
| const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var, | const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var, | ||||
| cudaStream_t stream); | cudaStream_t stream); | ||||
| // Explicit instantiation of the host launcher for half. | |||||
| template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon, | |||||
|                         const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var, | |||||
|                         cudaStream_t stream); | |||||
| @@ -19,6 +19,23 @@ | |||||
| #include "device/gpu/cuda_common.h" | #include "device/gpu/cuda_common.h" | ||||
| // Typed accessor for a kernel's dynamically sized shared memory. | |||||
| // An unsized `extern __shared__` array cannot be declared with a template-dependent element | |||||
| // type under one name, so each supported element type gets its own specialization with its | |||||
| // own array symbol. Only float and half are provided; other types are left undefined. | |||||
| template <typename T> | |||||
| struct DynamicSharedMem; | |||||
| // float view of the dynamic shared memory. | |||||
| template <> | |||||
| struct DynamicSharedMem<float> { | |||||
|   __device__ float *addr() { | |||||
|     extern __shared__ float addr_float[]; | |||||
|     float *base = addr_float; | |||||
|     return base; | |||||
|   } | |||||
| }; | |||||
| // half view of the dynamic shared memory. | |||||
| template <> | |||||
| struct DynamicSharedMem<half> { | |||||
|   __device__ half *addr() { | |||||
|     extern __shared__ half addr_half[]; | |||||
|     half *base = addr_half; | |||||
|     return base; | |||||
|   } | |||||
| }; | |||||
| template <typename T> | template <typename T> | ||||
| void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, | void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, | ||||
| const T* beta, T* y, T* mean, T* var, cudaStream_t stream); | const T* beta, T* y, T* mean, T* var, cudaStream_t stream); | ||||
| @@ -15,25 +15,38 @@ | |||||
| */ | */ | ||||
| #include "momentum_impl.cuh" | #include "momentum_impl.cuh" | ||||
| template <typename T> | |||||
| __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const T *learning_rate, | |||||
| const T *gradient, const T *momentum) { | |||||
| template <typename T, typename S> | |||||
| __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate, | |||||
| const T *gradient, const S *momentum) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | ||||
| accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; | accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; | ||||
| variable[i] -= learning_rate[0] * accumulation[i]; | variable[i] -= learning_rate[0] * accumulation[i]; | ||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate, const T *gradient, | |||||
| const T *momentum, cudaStream_t cuda_stream) { | |||||
| // Momentum update for half variables driven by float32 hyper-parameters: | |||||
| //   accumulation = momentum * accumulation + gradient | |||||
| //   variable    -= learning_rate * accumulation | |||||
| // learning_rate and momentum are each read from element 0 of their buffers. | |||||
| // The arithmetic is carried out in float and only the stored results are rounded to half, | |||||
| // so a small float32 learning rate is not rounded (or flushed to zero) before it is used. | |||||
| template <> | |||||
| __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation, | |||||
|                                              const float *learning_rate, const half *gradient, | |||||
|                                              const float *momentum) { | |||||
|   const float lr = learning_rate[0]; | |||||
|   const float mom = momentum[0]; | |||||
|   // Widen before multiplying so the index and stride cannot wrap in 32-bit arithmetic. | |||||
|   const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x; | |||||
|   for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < size; i += stride) { | |||||
|     const half new_accumulation = __float2half(mom * __half2float(accumulation[i]) + __half2float(gradient[i])); | |||||
|     accumulation[i] = new_accumulation; | |||||
|     // Step with the value that was actually stored, keeping variable and accumulation consistent. | |||||
|     variable[i] = __float2half(__half2float(variable[i]) - lr * __half2float(new_accumulation)); | |||||
|   } | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||||
| const S *momentum, cudaStream_t cuda_stream) { | |||||
| MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation, | MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation, | ||||
| learning_rate, gradient, momentum); | learning_rate, gradient, momentum); | ||||
| return; | return; | ||||
| } | } | ||||
| template void MomentumUpdateVariable<float>(const size_t size, float *variable, float *accumulation, | |||||
| const float *learning_rate, const float *gradient, const float *momentum, | |||||
| cudaStream_t cuda_stream); | |||||
| template void MomentumUpdateVariable<half>(const size_t size, half *variable, half *accumulation, | |||||
| const half *learning_rate, const half *gradient, const half *momentum, | |||||
| cudaStream_t cuda_stream); | |||||
| template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation, | |||||
| const float *learning_rate, const float *gradient, | |||||
| const float *momentum, cudaStream_t cuda_stream); | |||||
| template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation, | |||||
| const half *learning_rate, const half *gradient, | |||||
| const half *momentum, cudaStream_t cuda_stream); | |||||
| template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation, | |||||
| const float *learning_rate, const half *gradient, | |||||
| const float *momentum, cudaStream_t cuda_stream); | |||||
| @@ -18,8 +18,8 @@ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | ||||
| #include "device/gpu/cuda_common.h" | #include "device/gpu/cuda_common.h" | ||||
| template <typename T> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate, const T *gradient, | |||||
| const T *momentum, cudaStream_t cuda_stream); | |||||
| template <typename T, typename S> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||||
| const S *momentum, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | ||||
| @@ -27,5 +27,14 @@ MS_REG_GPU_KERNEL_ONE(LayerNorm, | |||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| LayerNormGpuKernel, float) | LayerNormGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(LayerNorm, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| LayerNormGpuKernel, half) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,5 +29,16 @@ MS_REG_GPU_KERNEL_ONE(LayerNormGrad, | |||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| LayerNormGradGpuKernel, float) | LayerNormGradGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(LayerNormGrad, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| LayerNormGradGpuKernel, half) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,7 +18,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| MS_REG_GPU_KERNEL_ONE(ApplyMomentum, | |||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeFloat32) | .AddInputAttr(kNumberTypeFloat32) | ||||
| .AddInputAttr(kNumberTypeFloat32) | .AddInputAttr(kNumberTypeFloat32) | ||||
| @@ -26,8 +26,8 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum, | |||||
| .AddInputAttr(kNumberTypeFloat32) | .AddInputAttr(kNumberTypeFloat32) | ||||
| .AddInputAttr(kNumberTypeFloat32) | .AddInputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| MomentumGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(ApplyMomentum, | |||||
| MomentumGpuKernel, float, float) | |||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| @@ -35,6 +35,15 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum, | |||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16), | .AddOutputAttr(kNumberTypeFloat16), | ||||
| MomentumGpuKernel, half) | |||||
| MomentumGpuKernel, half, half) | |||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| MomentumGpuKernel, half, float) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "kernel/gpu/cuda_impl/momentum_impl.cuh" | #include "kernel/gpu/cuda_impl/momentum_impl.cuh" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | |||||
| template <typename T, typename S> | |||||
| class MomentumGpuKernel : public GpuKernel { | class MomentumGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| MomentumGpuKernel() | MomentumGpuKernel() | ||||
| @@ -37,9 +37,9 @@ class MomentumGpuKernel : public GpuKernel { | |||||
| void *stream_ptr) override { | void *stream_ptr) override { | ||||
| T *variable = GetDeviceAddress<T>(inputs, 0); | T *variable = GetDeviceAddress<T>(inputs, 0); | ||||
| T *accumulation = GetDeviceAddress<T>(inputs, 1); | T *accumulation = GetDeviceAddress<T>(inputs, 1); | ||||
| T *learning_rate = GetDeviceAddress<T>(inputs, 2); | |||||
| S *learning_rate = GetDeviceAddress<S>(inputs, 2); | |||||
| T *gradient = GetDeviceAddress<T>(inputs, 3); | T *gradient = GetDeviceAddress<T>(inputs, 3); | ||||
| T *momentum = GetDeviceAddress<T>(inputs, 4); | |||||
| S *momentum = GetDeviceAddress<S>(inputs, 4); | |||||
| MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, | MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, | ||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| return true; | return true; | ||||
| @@ -53,9 +53,9 @@ class MomentumGpuKernel : public GpuKernel { | |||||
| variable_size_ = sizeof(T); | variable_size_ = sizeof(T); | ||||
| accumulation_size_ = sizeof(T); | accumulation_size_ = sizeof(T); | ||||
| learning_rate_size_ = sizeof(T); | |||||
| learning_rate_size_ = sizeof(S); | |||||
| gradient_size_ = sizeof(T); | gradient_size_ = sizeof(T); | ||||
| momentum_size_ = sizeof(T); | |||||
| momentum_size_ = sizeof(S); | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| for (size_t i = 0; i < variable_shape.size(); i++) { | for (size_t i = 0; i < variable_shape.size(); i++) { | ||||