
adagrad: support output on gpu

pull/13864/head
zhuyuxiao 4 years ago
parent commit a11287c332
4 changed files with 22 additions and 48 deletions
  1. +6  -40  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
  2. +0  -2   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
  3. +14 -4   mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
  4. +2  -2   tests/st/ops/gpu/test_adagrad_op.py

+6 -40  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu

@@ -32,16 +32,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                     const S *learning_rate,
                     const G *gradient,
                     T *variable,
-                    T *accumulation,
-                    T *variable_out,
-                    T *accumulation_out) {
+                    T *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }

@@ -51,16 +47,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                     const float *learning_rate,
                     const half *gradient,
                     half *variable,
-                    half *accumulation,
-                    half *variable_out,
-                    half *accumulation_out) {
+                    half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }

@@ -70,16 +62,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                     const float *learning_rate,
                     const half *gradient,
                     float *variable,
-                    float *accumulation,
-                    float *variable_out,
-                    float *accumulation_out) {
+                    float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += __half2float(gradient[i]) * __half2float(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * __half2float(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }

@@ -89,16 +77,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                     const half *learning_rate,
                     const float *gradient,
                     float *variable,
-                    float *accumulation,
-                    float *variable_out,
-                    float *accumulation_out) {
+                    float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __half2float(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }

@@ -108,16 +92,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                     const float *learning_rate,
                     const float *gradient,
                     half *variable,
-                    half *accumulation,
-                    half *variable_out,
-                    half *accumulation_out) {
+                    half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += __float2half(gradient[i]) * __float2half(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * __float2half(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }

@@ -128,11 +108,9 @@ void ApplyAdagrad(const size_t size,
                     const G *gradient,
                     T *variable,
                     T *accumulation,
-                    T *variable_out,
-                    T *accumulation_out,
                     cudaStream_t cuda_stream) {
   ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
-    size, update_slots, learning_rate, gradient, variable, accumulation, variable_out, accumulation_out);
+    size, update_slots, learning_rate, gradient, variable, accumulation);
 }

template void ApplyAdagrad<float, float, float>(const size_t size,
@@ -141,8 +119,6 @@ template void ApplyAdagrad<float, float, float>(const size_t size,
                     const float *gradient,
                     float *variable,
                     float *accumulation,
-                    float *variable_out,
-                    float *accumulation_out,
                     cudaStream_t cuda_stream);

template void ApplyAdagrad<half, half, half>(const size_t size,
@@ -151,8 +127,6 @@ template void ApplyAdagrad<half, half, half>(const size_t size,
                     const half *gradient,
                     half *variable,
                     half *accumulation,
-                    half *variable_out,
-                    half *accumulation_out,
                     cudaStream_t cuda_stream);

template void ApplyAdagrad<half, float, half>(const size_t size,
@@ -161,8 +135,6 @@ template void ApplyAdagrad<half, float, half>(const size_t size,
                     const half *gradient,
                     half *variable,
                     half *accumulation,
-                    half *variable_out,
-                    half *accumulation_out,
                     cudaStream_t cuda_stream);

template void ApplyAdagrad<float, float, half>(const size_t size,
@@ -171,8 +143,6 @@ template void ApplyAdagrad<float, float, half>(const size_t size,
                     const half *gradient,
                     float *variable,
                     float *accumulation,
-                    float *variable_out,
-                    float *accumulation_out,
                     cudaStream_t cuda_stream);

template void ApplyAdagrad<float, half, float>(const size_t size,
@@ -181,8 +151,6 @@ template void ApplyAdagrad<float, half, float>(const size_t size,
                     const float *gradient,
                     float *variable,
                     float *accumulation,
-                    float *variable_out,
-                    float *accumulation_out,
                     cudaStream_t cuda_stream);

template void ApplyAdagrad<half, float, float>(const size_t size,
@@ -191,6 +159,4 @@ template void ApplyAdagrad<half, float, float>(const size_t size,
                     const float *gradient,
                     half *variable,
                     half *accumulation,
-                    half *variable_out,
-                    half *accumulation_out,
                     cudaStream_t cuda_stream);
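
Note: after this change every kernel specialization updates variable and accumulation strictly in place; the variable_out / accumulation_out parameters are gone from the kernels and from the host-side ApplyAdagrad launcher. The following is a minimal, self-contained sketch (not MindSpore source) of what one Adagrad step now looks like on the device: float only, sqrtf standing in for the repository's SqrtFunc helper, and the launch configuration, buffer names and main() are illustrative assumptions.

#include <cstdio>
#include <cuda_runtime.h>

// In-place Adagrad step, mirroring the simplified kernel shape after this commit
// (float specialization; sqrtf replaces the repo's SqrtFunc helper).
__global__ void ApplyAdagradSketch(size_t size, bool update_slots, const float *learning_rate,
                                   const float *gradient, float *variable, float *accumulation) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    if (update_slots) {
      accumulation[i] += gradient[i] * gradient[i];  // accumulate squared gradient in place
    }
    variable[i] -= learning_rate[0] * gradient[i] / sqrtf(accumulation[i]);  // in-place variable update
  }
}

int main() {
  const size_t n = 4;
  float *var, *accum, *grad, *lr;
  cudaMallocManaged(&var, n * sizeof(float));
  cudaMallocManaged(&accum, n * sizeof(float));
  cudaMallocManaged(&grad, n * sizeof(float));
  cudaMallocManaged(&lr, sizeof(float));
  for (size_t i = 0; i < n; ++i) { var[i] = 1.0f; accum[i] = 0.1f; grad[i] = 0.5f; }
  lr[0] = 0.01f;

  ApplyAdagradSketch<<<1, 128>>>(n, true, lr, grad, var, accum);  // one optimizer step
  cudaDeviceSynchronize();

  printf("var[0]=%f accum[0]=%f\n", var[0], accum[0]);  // expect accum = 0.35, var ~ 0.99155
  return 0;
}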

+0 -2  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh

@@ -25,8 +25,6 @@ void ApplyAdagrad(const size_t size,
                     const G *gradient,
                     T *variable,
                     T *accumulation,
-                    T *variable_out,
-                    T *accumulation_out,
                     cudaStream_t stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_

+14 -4  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h

@@ -45,7 +45,17 @@ class AdagradGpuKernel : public GpuKernel {
     T *variable_out = GetDeviceAddress<T>(outputs, 0);
     T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
     ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
-                 variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&variable_out[0], &variable[0], variable_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&accumulation_out[0], &accumulation[0], accumulation_size_,
+                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+
     return true;
   }
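
Note: since the kernel no longer writes the outputs, Launch now runs the in-place update on the input buffers and then mirrors var/accum into the kernel's two output addresses with asynchronous device-to-device copies queued on the same stream, which orders them after the kernel. Below is a standalone sketch of that pattern in plain CUDA; the helper name, the launch configuration and the omission of the CHECK_CUDA_RET_WITH_EXCEPT error handling are assumptions made for brevity, and the kernel is the one sketched after the adagrad_impl.cu diff above.

#include <cuda_runtime.h>

// Defined as in the sketch above: in-place Adagrad update on variable/accumulation.
__global__ void ApplyAdagradSketch(size_t size, bool update_slots, const float *learning_rate,
                                   const float *gradient, float *variable, float *accumulation);

// Hypothetical host helper: update var/accum in place, then copy them to the output
// buffers the caller expects. Error checking omitted for brevity.
void LaunchAdagradSketch(size_t size, bool update_slots, const float *lr, const float *grad,
                         float *variable, float *accumulation,
                         float *variable_out, float *accumulation_out, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((size + threads - 1) / threads);
  ApplyAdagradSketch<<<blocks, threads, 0, stream>>>(size, update_slots, lr, grad, variable, accumulation);

  // The copies share the stream with the kernel, so they run only after it finishes.
  cudaMemcpyAsync(variable_out, variable, size * sizeof(float), cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(accumulation_out, accumulation, size * sizeof(float), cudaMemcpyDeviceToDevice, stream);
}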

@@ -61,17 +71,17 @@ class AdagradGpuKernel : public GpuKernel {
     learning_rate_size_ = sizeof(S);
     gradient_size_ = sizeof(G);
 
-    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < variable_shape.size(); i++) {
       variable_size_ *= variable_shape[i];
     }
 
-    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     for (size_t i = 0; i < accumulation_shape.size(); i++) {
       accumulation_size_ *= accumulation_shape[i];
     }
 
-    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
     for (size_t i = 0; i < gradient_shape.size(); i++) {
       gradient_size_ *= gradient_shape[i];
     }

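Note: this hunk corrects the input indices used for shape inference. The corrected calls read var from input 0, accum from input 1 and grad from input 3, which leaves the scalar learning rate at input 2 (inferred; the diff does not show it). A small standalone sketch of the byte-size computation under that assumed ordering, with the AnfAlgo::GetPrevNodeOutputInferShape call replaced by a plain shape vector:

#include <cstddef>
#include <vector>

// Assumed input ordering for ApplyAdagrad, matching the corrected indices in this hunk
// (the learning-rate index is an inference, not part of the diff).
constexpr size_t kVarIdx = 0;
constexpr size_t kAccumIdx = 1;
constexpr size_t kLrIdx = 2;
constexpr size_t kGradIdx = 3;

// Byte size of a tensor: element size multiplied by every shape dimension, as in Init().
size_t SizeInBytes(const std::vector<size_t> &shape, size_t elem_size) {
  size_t size = elem_size;
  for (size_t dim : shape) {
    size *= dim;
  }
  return size;
}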

+2 -2  tests/st/ops/gpu/test_adagrad_op.py

@@ -36,8 +36,8 @@ class Net(nn.Cell):
         self.accum = Parameter(Tensor(accum_np), name="accum")
 
     def construct(self, lr, grad):
-        self.apply_adagrad(self.var, self.accum, lr, grad)
-        return self.var, self.accum
+        z = self.apply_adagrad(self.var, self.accum, lr, grad)
+        return z
 
 
 @pytest.mark.level0

