diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cu
new file mode 100644
index 0000000000..a0407d0e67
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cu
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sparse_ftrl_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+__device__ __forceinline__ T PowFunc(T x, T y) {
+  return pow(x, y);
+}
+
+template <>
+__device__ __forceinline__ half PowFunc(half x, half y) {
+  return __float2half(pow(__half2float(x), __half2float(y)));
+}
+
+template <typename T>
+__device__ __forceinline__ bool CompareFunc(T x, T y) {
+  return abs(x) > y;
+}
+
+template <>
+__device__ __forceinline__ bool CompareFunc(half x, half y) {
+  return abs(__half2float(x)) > __half2float(y);
+}
+
+template <typename T>
+__device__ __forceinline__ T Sgn(T x) {
+  return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
+}
+
+template <>
+__device__ __forceinline__ half Sgn(half x) {
+  return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0);
+}
+
+template <typename T, typename S>
+__global__ void SparseApplyFtrlKernel(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
+                                      const float learning_rate, const float l1_regularization,
+                                      const float l2_regularization, const float learning_rate_power,
+                                      T *variable, T *accumulation, T *linear) {
+  const T two = static_cast<T>(2.0);
+  const T learning_rate_val = static_cast<T>(learning_rate);
+  const T l1_regularization_val = static_cast<T>(l1_regularization);
+  const T l2_regularization_val = static_cast<T>(l2_regularization);
+  const T learning_rate_power_val = static_cast<T>(-learning_rate_power);
+
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
+       pos < (num_index * n_stride);
+       pos += gridDim.x * blockDim.x) {
+    // Map the flattened thread position to a row in `indices` and an offset inside that row.
+    const int posn = pos / n_stride;
+    const int posi = pos % n_stride;
+    const int indexed_n = indices[posn];
+    const int i = indexed_n * n_stride + posi;
+    // FTRL-proximal update: accumulate the squared gradient, then refresh linear and variable.
+    const T cur_accumulation = accumulation[i] + gradient[pos] * gradient[pos];
+    const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val);
+    const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val);
+    const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate_val;
+
+    linear[i] += gradient[pos] - sigma * variable[i];
+    variable[i] = CompareFunc(linear[i], l1_regularization_val)
+                    ? ((l1_regularization_val * Sgn(linear[i]) - linear[i]) /
+                       (cur_accumulation_power / learning_rate_val + two * l2_regularization_val))
+                    : static_cast<T>(0);
+    accumulation[i] = cur_accumulation;
+  }
+  return;
+}
+
+template <typename T, typename S>
+void CalSparseApplyFtrl(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
+                        const float learning_rate, const float l1_regularization, const float l2_regularization,
+                        const float learning_rate_power, const bool use_locking, T *variable, T *accumulation,
+                        T *linear, cudaStream_t cuda_stream) {
+  SparseApplyFtrlKernel<<<GET_BLOCKS(num_index * n_stride), GET_THREADS, 0, cuda_stream>>>(
+    gradient, indices, num_index, n_stride, learning_rate, l1_regularization, l2_regularization, learning_rate_power,
+    variable, accumulation, linear);
+  return;
+}
+
+template void CalSparseApplyFtrl<float, int>(const float *gradient, const int *indices, const int num_index,
+                                             const size_t n_stride, const float learning_rate,
+                                             const float l1_regularization, const float l2_regularization,
+                                             const float learning_rate_power, const bool use_locking, float *variable,
+                                             float *accumulation, float *linear, cudaStream_t cuda_stream);
+template void CalSparseApplyFtrl<half, int>(const half *gradient, const int *indices, const int num_index,
+                                            const size_t n_stride, const float learning_rate,
+                                            const float l1_regularization, const float l2_regularization,
+                                            const float learning_rate_power, const bool use_locking, half *variable,
+                                            half *accumulation, half *linear, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh
new file mode 100644
index 0000000000..c5cebe6d0f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_
+template <typename T, typename S>
+void CalSparseApplyFtrl(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
+                        const float learning_rate, const float l1_regularization, const float l2_regularization,
+                        const float learning_rate_power, const bool use_locking, T *variable, T *accumulation,
+                        T *linear, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.cc
new file mode 100644
index 0000000000..2ee596a4f0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SparseFtrlGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      SparseFtrlGpuKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h
new file mode 100644
index 0000000000..62dc3b64ef
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class SparseFtrlGpuKernel : public GpuKernel {
+ public:
+  SparseFtrlGpuKernel()
+      : variable_size_(0),
+        accumulation_size_(0),
+        linear_size_(0),
+        gradient_size_(0),
+        indices_size_(0),
+        lr_(0.0f),
+        l1_(0.0f),
+        l2_(0.0f),
+        lr_power_(0.0f),
+        use_locking_(false),
+        num_index_(0),
+        n_stride_(1) {}
+
+  ~SparseFtrlGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *accumulation = GetDeviceAddress<T>(inputs, 1);
+    T *linear = GetDeviceAddress<T>(inputs, 2);
+    T *gradient = GetDeviceAddress<T>(inputs, 3);
+    S *indices = GetDeviceAddress<S>(inputs, 4);
+    CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
+                       accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 5) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but sparse ftrl needs 5 inputs.";
+      return false;
+    }
+
+    variable_size_ = sizeof(T);
+    accumulation_size_ = sizeof(T);
+    linear_size_ = sizeof(T);
+    gradient_size_ = sizeof(T);
+    indices_size_ = sizeof(S);
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      variable_size_ *= variable_shape[i];
+      if (i > 0) {
+        n_stride_ *= variable_shape[i];
+      }
+    }
+
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < accumulation_shape.size(); i++) {
+      accumulation_size_ *= accumulation_shape[i];
+    }
+
+    auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    for (size_t i = 0; i < linear_shape.size(); i++) {
+      linear_size_ *= linear_shape[i];
+    }
+
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    for (size_t i = 0; i < gradient_shape.size(); i++) {
+      gradient_size_ *= gradient_shape[i];
+    }
+
+    auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
+    for (size_t i = 0; i < indices_shape.size(); i++) {
+      indices_size_ *= indices_shape[i];
+    }
+
+    lr_ = GetAttr<float>(kernel_node, "lr");
+    l1_ = GetAttr<float>(kernel_node, "l1");
+    l2_ = GetAttr<float>(kernel_node, "l2");
+    lr_power_ = GetAttr<float>(kernel_node, "lr_power");
+    use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
+    num_index_ = indices_shape[0];
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(variable_size_);
+    input_size_list_.push_back(accumulation_size_);
+    input_size_list_.push_back(linear_size_);
+    input_size_list_.push_back(gradient_size_);
+    input_size_list_.push_back(indices_size_);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t variable_size_;
+  size_t accumulation_size_;
+  size_t linear_size_;
+  size_t gradient_size_;
+  size_t indices_size_;
+  float lr_;
+  float l1_;
+  float l2_;
+  float lr_power_;
+  bool use_locking_;
+  int num_index_;
+  size_t n_stride_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_sparse_apply_ftrl_op.py b/tests/st/ops/gpu/test_sparse_apply_ftrl_op.py
new file mode 100644
index 0000000000..9519bac4d0
--- /dev/null
+++ b/tests/st/ops/gpu/test_sparse_apply_ftrl_op.py
@@ -0,0 +1,149 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
+        self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
+        self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")
+
+    def construct(self, grad, indices):
+        out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
+        return out
+
+
+class Net_half(nn.Cell):
+    def __init__(self):
+        super(Net_half, self).__init__()
+        self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
+        self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="var")
+        self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="accum")
+        self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="linear")
+
+    def construct(self, grad, indices):
+        out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
+        return out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ftrl():
+    gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
+    indices = Tensor([0, 1, 2], mstype.int32)
+    expect_var = np.array([[[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ftrl_sparse():
+    gradient = Tensor(np.ones([2, 3, 3]).astype(np.float32))
+    indices = Tensor([0, 2], mstype.int32)
+    expect_var = np.array([[[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[1, 1, 1],
+                            [1, 1, 1],
+                            [1, 1, 1]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ftrl_half():
+    gradient = Tensor(np.ones([3, 3, 3]).astype(np.float16))
+    indices = Tensor([0, 1, 2], mstype.int32)
+    expect_var = np.array([[[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]]]).astype(np.float16)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net_half()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net_half()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ftrl_sparse_half():
+    gradient = Tensor(np.ones([2, 3, 3]).astype(np.float16))
+    indices = Tensor([0, 2], mstype.int32)
+    expect_var = np.array([[[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]],
+                           [[1, 1, 1],
+                            [1, 1, 1],
+                            [1, 1, 1]],
+                           [[0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479],
+                            [0.291479, 0.291479, 0.291479]]]).astype(np.float16)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net_half()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    sparse_apply_ftrl = Net_half()
+    sparse_apply_ftrl(gradient, indices)
+    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
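
Reviewer note (not part of the patch): the expected value 0.291479 in the tests comes directly from one step of the FTRL-proximal update that the CUDA kernel implements, with var = accum = linear = grad = 1, lr = 0.001, l1 = l2 = 0, lr_power = -0.5. The sketch below is a hedged NumPy reference (the helper name `ftrl_reference` is hypothetical, not an API in this patch) that reproduces the expected tensors for the sparse case, updating only the rows selected by `indices`:

```python
import numpy as np

def ftrl_reference(var, accum, linear, grad, indices, lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5):
    # One FTRL-proximal step over the rows named in `indices`, mirroring the kernel's math.
    var, accum, linear = var.copy(), accum.copy(), linear.copy()
    for pos, idx in enumerate(indices):
        g = grad[pos]
        accum_new = accum[idx] + g * g
        sigma = (accum_new ** -lr_power - accum[idx] ** -lr_power) / lr
        linear[idx] += g - sigma * var[idx]
        quadratic = accum_new ** -lr_power / lr + 2.0 * l2
        var[idx] = np.where(np.abs(linear[idx]) > l1,
                            (l1 * np.sign(linear[idx]) - linear[idx]) / quadratic, 0.0)
        accum[idx] = accum_new
    return var, accum, linear

ones = np.ones([3, 3, 3], np.float32)
var, _, _ = ftrl_reference(ones, ones, ones, np.ones([2, 3, 3], np.float32), indices=[0, 2])
print(var[0, 0, 0])  # ~0.291479, matches expect_var rows 0 and 2 in test_ftrl_sparse
print(var[1, 0, 0])  # 1.0, row 1 is not indexed and stays untouched
```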