|
|
#include "momentum_impl.cuh"

// Applies one ApplyMomentum update step in place to `variable` and
// `accumulation` (each `size` elements, device memory). `learning_rate` and
// `momentum` are scalars read from element 0 of device pointers.
//
//   accumulation = momentum * accumulation + gradient
//   use_nesterov == false: variable -= learning_rate * accumulation
//   use_nesterov == true : variable -= gradient * learning_rate
//                                      + accumulation * momentum * learning_rate
//     (Nesterov-style variant, per the flag's name — matches the two-term
//      update used by ApplyMomentum(use_nesterov=true).)
//
// Grid-stride loop: correct for any grid/block launch configuration.
template <typename T, typename S, typename G>
__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
                                             const G *gradient, const S *momentum, bool use_nesterov) {
  // Branch once on the launch-uniform flag so each per-element loop stays tight
  // and free of divergence.
  if (use_nesterov) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
      variable[i] -= gradient[i] * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0];
    }
  } else {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
      variable[i] -= learning_rate[0] * accumulation[i];
    }
  }
}
|
|
// Specialization for half-precision variable/accumulation/gradient with
// float scalars. The float scalars `learning_rate[0]` and `momentum[0]` are
// converted to half with __float2half before each half arithmetic op; all
// accumulation is done in half. Same update rule as the generic kernel.
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
                                             const float *learning_rate, const half *gradient, const float *momentum,
                                             bool use_nesterov) {
  if (use_nesterov) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
      variable[i] -= gradient[i] * __float2half(learning_rate[0]) +
                     accumulation[i] * __float2half(momentum[0]) * __float2half(learning_rate[0]);
    }
  } else {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
      variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
    }
  }
}
|
|
// Specialization for float variable/accumulation with half gradients and
// float scalars. Gradients are widened with __half2float so the accumulation
// and the variable update are computed entirely in float. Same update rule as
// the generic kernel.
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
                                             const float *learning_rate, const half *gradient, const float *momentum,
                                             bool use_nesterov) {
  if (use_nesterov) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
      variable[i] -= __half2float(gradient[i]) * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0];
    }
  } else {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
      accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
      variable[i] -= learning_rate[0] * accumulation[i];
    }
  }
}
|
|
// Host-side launcher: runs MomentumUpdateVariableKernel over `size` elements
// asynchronously on `cuda_stream`. Launch geometry comes from the project's
// GET_BLOCKS / GET_THREADS helpers; the grid-stride kernel is correct for any
// geometry they choose. All pointer arguments must be device memory.
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                            const S *momentum, bool use_nesterov, cudaStream_t cuda_stream) {
  MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, variable, accumulation, learning_rate, gradient, momentum, use_nesterov);
}
|
|
|
|
|
|
|
|
// Explicit instantiations of the host launcher for the type combinations the
// momentum GPU kernel is registered with (variable/accum type, scalar type,
// gradient type). All carry the `use_nesterov` flag.
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
                                                          const float *learning_rate, const float *gradient,
                                                          const float *momentum, bool use_nesterov,
                                                          cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
                                                       const half *learning_rate, const half *gradient,
                                                       const half *momentum, bool use_nesterov,
                                                       cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
                                                        const float *learning_rate, const half *gradient,
                                                        const float *momentum, bool use_nesterov,
                                                        cudaStream_t cuda_stream);
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
                                                         const float *learning_rate, const half *gradient,
                                                         const float *momentum, bool use_nesterov,
                                                         cudaStream_t cuda_stream);
|
|
|
|
|
|
|
|
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale, |
|
|
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale, |
|
|
float *variable, float *accumulation, const float *learning_rate, |
|
|
float *variable, float *accumulation, const float *learning_rate, |
|
|
|