Browse Source

add GPU SparseApplyProximalAdagrad

tags/v1.1.0
TFbunny 5 years ago
parent
commit
a638973378
5 changed files with 449 additions and 0 deletions
  1. +103
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cu
  2. +27
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cuh
  3. +46
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.cc
  4. +140
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.h
  5. +133
    -0
      tests/st/ops/gpu/test_sparse_apply_proximal_adagrad_op.py

+ 103
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cu View File

@@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cuh"

// Greater-than comparison helper. The primary template covers types with a
// native operator> (only float is instantiated in this file).
template <typename T>
__device__ __forceinline__ bool CompareFunc(T x, T y) {
  return x > y;
}

// half specialization: compare in float so the helper does not depend on
// native half operator> support, which varies across architectures.
template <>
__device__ __forceinline__ bool CompareFunc(half x, half y) {
  return __half2float(x) > __half2float(y);
}

// Reciprocal square root, 1/sqrt(x). The primary template uses the
// float-only intrinsic __frsqrt_rn (round-to-nearest-even), so it is only
// valid for T = float; half goes through the specialization below.
template <typename T>
__device__ __forceinline__ T RsqrtFunc(T x) {
  return __frsqrt_rn(x);
}

// half specialization: native half-precision rsqrt intrinsic.
template <>
__device__ __forceinline__ half RsqrtFunc(half x) {
  return hrsqrt(x);
}

// Absolute value helper. Only float is instantiated through the primary
// template (mirroring RsqrtFunc above); half uses the specialization below.
template <typename T>
__device__ __forceinline__ T AbsFunc(T x) {
  // Use fabsf explicitly: an unqualified abs() can resolve to the integer
  // overload in device code, silently truncating fractional values.
  return fabsf(x);
}

template <>
__device__ __forceinline__ half AbsFunc(half x) {
  // No portable native half abs across the supported archs; round-trip
  // through float, matching the other half specializations in this file.
  return __float2half(fabsf(__half2float(x)));
}

// Signum: returns -1, 0 or +1 (as type T) for negative, zero and positive x.
// Matches the original formulation exactly, including NaN mapping to -1.
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
  if (x == 0) {
    return static_cast<T>(0);
  }
  return static_cast<T>(x > 0 ? 1 : -1);
}

// half specialization: evaluate the sign in float precision.
template <>
__device__ __forceinline__ half Sgn(half x) {
  const float xf = __half2float(x);
  if (xf == 0) {
    return __float2half(0);
  }
  return __float2half(xf > 0 ? 1 : -1);
}

// One proximal-adagrad update step, parallelized over gradient elements.
// `size` is the total number of gradient elements and must equal
// indices_size * inner_size, where inner_size is the per-row element count of
// `variable`. For gradient row r, indices[r] names the variable/accumulation
// row it updates. `variable` and `accumulation` are updated in place and the
// results are mirrored into `variable_out`/`accumulation_out`.
// NOTE(review): rows are updated without atomics, so duplicate values in
// `indices` race with each other — presumably callers guarantee unique
// indices; confirm before relying on duplicate-index behavior.
template <typename T>
__global__ void SparseApplyProximalAdagradUpdate(const size_t size, const size_t indices_size, const T *learning_rate,
                                                 const T *l1_regularization, const T *l2_regularization,
                                                 const T *gradient, const int *indices, T *variable, T *accumulation,
                                                 T *variable_out, T *accumulation_out) {
  const int inner_size = static_cast<int>(size / indices_size);
  // Grid-stride loop over every gradient element.
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int>(size); pos += gridDim.x * blockDim.x) {
    const int index = pos / inner_size;
    const int inner_pos = pos % inner_size;
    const int grad_pos = pos;
    // Position of the targeted element inside variable/accumulation.
    const int cur_pos = indices[index] * inner_size + inner_pos;
    // accum += grad^2
    accumulation[cur_pos] += gradient[grad_pos] * gradient[grad_pos];
    // Effective step size: lr / sqrt(accum).
    const T scratch1 = learning_rate[0] * RsqrtFunc(accumulation[cur_pos]);
    T prox_v = variable[cur_pos];
    prox_v -= gradient[grad_pos] * scratch1;
    // Soft-threshold magnitude: max(|prox_v| - step * l1, 0).
    const T scratch2 = AbsFunc(prox_v) - scratch1 * l1_regularization[0];
    const T scratch3 = CompareFunc(scratch2, static_cast<T>(0.0)) ? scratch2 : static_cast<T>(0.0);
    // Apply l1 shrinkage only when l1 > 0, then the l2 denominator.
    variable[cur_pos] = CompareFunc(l1_regularization[0], static_cast<T>(0.0)) ? Sgn(prox_v) * scratch3 : prox_v;
    variable[cur_pos] = variable[cur_pos] / (static_cast<T>(1.0) + l2_regularization[0] * scratch1);
    accumulation_out[cur_pos] = accumulation[cur_pos];
    variable_out[cur_pos] = variable[cur_pos];
  }
}

// Host-side launcher for SparseApplyProximalAdagradUpdate.
// `size` is the total gradient element count; `indices_size` the number of
// gradient rows. Grid/block dimensions come from the project's
// GET_BLOCKS/GET_THREADS helpers. The launch is asynchronous on
// `cuda_stream`; no CUDA error checking is performed here.
template <typename T>
void CalSparseApplyProximalAdagrad(const size_t size, const size_t indices_size, const T *learning_rate,
                                   const T *l1_regularization, const T *l2_regularization, const T *gradient,
                                   const int *indices, T *variable, T *accumulation, T *variable_out,
                                   T *accumulation_out, cudaStream_t cuda_stream) {
  SparseApplyProximalAdagradUpdate<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, indices_size, learning_rate, l1_regularization, l2_regularization, gradient, indices, variable, accumulation,
    variable_out, accumulation_out);
}

// Explicit instantiations for the dtypes registered in the kernel factory.
template void CalSparseApplyProximalAdagrad<float>(const size_t size, const size_t indices_size,
                                                   const float *learning_rate, const float *l1_regularization,
                                                   const float *l2_regularization, const float *gradient,
                                                   const int *indices, float *variable, float *accumulation,
                                                   float *variable_out, float *accumulation_out,
                                                   cudaStream_t cuda_stream);
template void CalSparseApplyProximalAdagrad<half>(const size_t size, const size_t indices_size,
                                                  const half *learning_rate, const half *l1_regularization,
                                                  const half *l2_regularization, const half *gradient,
                                                  const int *indices, half *variable, half *accumulation,
                                                  half *variable_out, half *accumulation_out, cudaStream_t cuda_stream);

+ 27
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cuh View File

@@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMP_SPARSE_APPLY_PROXIMAL_ADAGRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMP_SPARSE_APPLY_PROXIMAL_ADAGRAD_IMPL_CUH_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalSparseApplyProximalAdagrad(const size_t size, const size_t indices_size, const T *learning_rate,
const T *l1_regularization, const T *l2_regularization, const T *gradient,
const int *indices, T *variable, T *accumulation, T *variable_out,
T *accumulation_out, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMP_SPARSE_APPLY_PROXIMAL_ADAGRAD_IMPL_CUH_

+ 46
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.cc View File

@@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.h"

namespace mindspore {
namespace kernel {
// Register the GPU implementation of the SparseApplyProximalAdagrad
// primitive. Input order: var, accum, lr, l1, l2, grad, indices;
// outputs: updated var, updated accum.
// float32 variant.
MS_REG_GPU_KERNEL_ONE(SparseApplyProximalAdagrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      SparseApplyProximalAdagradKernel, float)
// float16 variant (indices remain int32).
MS_REG_GPU_KERNEL_ONE(SparseApplyProximalAdagrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      SparseApplyProximalAdagradKernel, half)
}  // namespace kernel
}  // namespace mindspore

+ 140
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.h View File

@@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the
* "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the
* License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in
* writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_APPLY_PROXIMAL_ADAGRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_APPLY_PROXIMAL_ADAGRAD_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sparse_apply_proximal_adagrad_impl.cuh"

namespace mindspore {
namespace kernel {
// GPU kernel wrapper for SparseApplyProximalAdagrad.
// Inputs (in order): var, accum, lr, l1, l2, grad, indices.
// Outputs: updated var, updated accum.
template <typename T>
class SparseApplyProximalAdagradKernel : public GpuKernel {
 public:
  SparseApplyProximalAdagradKernel()
      : variable_size_(0),
        accumulation_size_(0),
        learning_rate_size_(0),
        l1_regularization_size_(0),
        l2_regularization_size_(0),
        gradient_size_(0),
        indices_size_(0) {}

  ~SparseApplyProximalAdagradKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Launches one sparse proximal-adagrad update on `stream_ptr`
  // (asynchronous; errors surface at the framework's next sync point).
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *variable = GetDeviceAddress<T>(inputs, 0);
    T *accumulation = GetDeviceAddress<T>(inputs, 1);
    T *learning_rate = GetDeviceAddress<T>(inputs, 2);
    T *l1_regularization = GetDeviceAddress<T>(inputs, 3);
    T *l2_regularization = GetDeviceAddress<T>(inputs, 4);
    T *gradient = GetDeviceAddress<T>(inputs, 5);
    int *indices = GetDeviceAddress<int>(inputs, 6);
    T *variable_out = GetDeviceAddress<T>(outputs, 0);
    T *accumulation_out = GetDeviceAddress<T>(outputs, 1);

    // The CUDA kernel iterates over every *gradient* element, so its `size`
    // argument must be the gradient element count (inputs[5]), not the
    // variable element count (inputs[0]). With sparse indices the gradient
    // only has indices_size rows; passing the variable size would both read
    // past the end of `gradient` and derive a wrong inner row size.
    CalSparseApplyProximalAdagrad(inputs[5]->size / sizeof(T), indices_size_ / sizeof(int), learning_rate,
                                  l1_regularization, l2_regularization, gradient, indices, variable, accumulation,
                                  variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  // Computes the byte size of every input/output buffer from the inferred
  // shapes. Returns false (after logging) when the input count is wrong.
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 7) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but SparseApplyProximalAdagrad needs 7 inputs.";
      return false;
    }

    // Start every size at one element; multiply in the shape dims below.
    variable_size_ = sizeof(T);
    accumulation_size_ = sizeof(T);
    learning_rate_size_ = sizeof(T);
    l1_regularization_size_ = sizeof(T);
    l2_regularization_size_ = sizeof(T);
    gradient_size_ = sizeof(T);
    indices_size_ = sizeof(int);

    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (size_t i = 0; i < variable_shape.size(); i++) {
      variable_size_ *= variable_shape[i];
    }

    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    for (size_t i = 0; i < accumulation_shape.size(); i++) {
      accumulation_size_ *= accumulation_shape[i];
    }

    auto learning_rate_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    for (size_t i = 0; i < learning_rate_shape.size(); i++) {
      learning_rate_size_ *= learning_rate_shape[i];
    }

    // l1/l2 are scalars in practice, but fold their shapes in anyway so the
    // reported buffer sizes stay correct for any inferred shape, consistent
    // with how every other input is handled.
    auto l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
    for (size_t i = 0; i < l1_shape.size(); i++) {
      l1_regularization_size_ *= l1_shape[i];
    }

    auto l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
    for (size_t i = 0; i < l2_shape.size(); i++) {
      l2_regularization_size_ *= l2_shape[i];
    }

    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5);
    for (size_t i = 0; i < gradient_shape.size(); i++) {
      gradient_size_ *= gradient_shape[i];
    }

    auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6);
    for (size_t i = 0; i < indices_shape.size(); i++) {
      indices_size_ *= indices_shape[i];
    }
    InitSizeLists();
    return true;
  }

 protected:
  // Publishes the per-buffer byte sizes computed in Init().
  void InitSizeLists() override {
    input_size_list_.push_back(variable_size_);
    input_size_list_.push_back(accumulation_size_);
    input_size_list_.push_back(learning_rate_size_);
    input_size_list_.push_back(l1_regularization_size_);
    input_size_list_.push_back(l2_regularization_size_);
    input_size_list_.push_back(gradient_size_);
    input_size_list_.push_back(indices_size_);
    output_size_list_.push_back(variable_size_);
    output_size_list_.push_back(accumulation_size_);
  }

 private:
  // All sizes below are in bytes, derived from inferred shapes in Init().
  size_t variable_size_;
  size_t accumulation_size_;
  size_t learning_rate_size_;
  size_t l1_regularization_size_;
  size_t l2_regularization_size_;
  size_t gradient_size_;
  size_t indices_size_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_APPLY_PROXIMAL_ADAGRAD_KERNEL_H_

+ 133
- 0
tests/st/ops/gpu/test_sparse_apply_proximal_adagrad_op.py View File

@@ -0,0 +1,133 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    """Cell wrapping the SparseApplyProximalAdagrad primitive.

    var/accum are registered as Parameters so the op can update them in
    place; lr/l1/l2 are captured as plain scalar attributes.
    """

    def __init__(self, var, accum, lr, l1, l2):
        super(Net, self).__init__()
        self.var = Parameter(var, name="var")
        self.accum = Parameter(accum, name="accum")
        self.lr = lr
        self.l1 = l1
        self.l2 = l2
        self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()

    def construct(self, grad, indices):
        return self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr,
                                                  self.l1, self.l2, grad, indices)

def add_testcase(var, accum, lr, l1, l2, grad, indices):
    """Build a Net with the given hyper-parameters and run one update step."""
    return Net(var, accum, lr, l1, l2)(grad, indices)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_small_shape():
    """One update on a 3x3 float32 variable with permuted row indices."""
    var = Tensor(np.arange(9, dtype=np.float32).reshape(3, 3))
    accum = Tensor(np.zeros((3, 3), dtype=np.float32))
    grad = Tensor(np.full((3, 3), 8.0, dtype=np.float32))
    indices = Tensor(np.array([1, 0, 2], np.int32))
    output1, output2 = add_testcase(var, accum, 1.0, 1.0, 0.0, grad, indices)
    expect1 = np.array([[-0.875, 0.0, 0.875],
                        [1.875, 2.875, 3.875],
                        [4.875, 5.875, 6.875]])
    expect2 = np.full((3, 3), 64.0)
    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_lr_l1_l2():
    """Large lr/l1/l2 values drive every updated var entry to zero."""
    var = Tensor(np.arange(9, dtype=np.float32).reshape(3, 3))
    accum = Tensor(np.zeros((3, 3), dtype=np.float32))
    grad = Tensor(np.full((3, 3), 8.0, dtype=np.float32))
    indices = Tensor(np.array([1, 0, 2], np.int32))
    output1, output2 = add_testcase(var, accum, 100.0, 34.0, 16.0, grad, indices)
    np.testing.assert_array_almost_equal(output1.asnumpy(), np.zeros((3, 3)))
    np.testing.assert_array_almost_equal(output2.asnumpy(), np.full((3, 3), 64.0))

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_np_arange():
    """Update step with non-zero initial accum, non-zero l2, reversed indices."""
    var = Tensor(np.arange(9).reshape(3, 3).astype(np.float32))
    accum = Tensor(np.arange(63, 72).reshape(3, 3).astype(np.float32))
    lr = 1.0
    l1 = 1.0
    l2 = 2.0
    grad = Tensor(np.arange(34, 43).reshape(3, 3).astype(np.float32) * 8)
    indices = Tensor(np.array([2, 1, 0], np.int32))
    output1, output2 = add_testcase(var, accum, lr, l1, l2, grad, indices)
    # Expected values are precomputed for this fixed input; indices [2, 1, 0]
    # map gradient row r onto variable row (2 - r).
    expect1 = np.array([[-0.99038047, 0., 0.9914129],
                        [1.9836018, 2.9774926, 3.9716945],
                        [4.9603353, 5.9543643, 6.948723]])
    expect2 = np.array([[102463., 107648., 112961.],
                        [87682., 92483., 97412.],
                        [74053., 78470., 83015.]])
    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_large_shape():
    """3-D variable (2, 3, 4): indices address only the outermost axis."""
    var = Tensor(np.arange(24).reshape((2, 3, 4)).astype(np.float32))
    accum = Tensor(np.arange(34, 58).reshape((2, 3, 4)).astype(np.float32))
    lr = 1.0
    l1 = 1.0
    l2 = 2.0
    grad = Tensor(np.ones(24).reshape((2, 3, 4)).astype(np.float32) * 2)
    indices = Tensor(np.arange(2).astype(np.int32))
    output1, output2 = add_testcase(var, accum, lr, l1, l2, grad, indices)
    # Expected outputs were reportedly generated by the "D" (Ascend) chip
    # implementation of this op -- TODO confirm provenance.
    expect1 = np.array([[[-0.12248275, 0.39357165, 1.1591142, 1.9289699],
                         [2.7029436, 3.4808538, 4.2625313, 5.0478177],
                         [5.836565, 6.6286335, 7.423894, 8.222222]],
                        [[9.023503, 9.82763, 10.634497, 11.444007],
                         [12.256072, 13.0706005, 13.887513, 14.706733],
                         [15.528182, 16.35179, 17.177492, 18.005226]]])
    expect2 = np.array([[[38., 39., 40., 41.],
                         [42., 43., 44., 45.],
                         [46., 47., 48., 49.]],
                        [[50., 51., 52., 53.],
                         [54., 55., 56., 57.],
                         [58., 59., 60., 61.]]])
    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)

Loading…
Cancel
Save