| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/cuda_impl/ftrl_impl.cuh" | |||||
| template <typename T> | |||||
| __device__ __forceinline__ T PowFunc(T x, T y) { | |||||
| return pow(x, y); | |||||
| } | |||||
| template <> | |||||
| __device__ __forceinline__ half PowFunc(half x, half y) { | |||||
| return __float2half(pow(__half2float(x), __half2float(y))); | |||||
| } | |||||
| template <typename T> | |||||
| __device__ __forceinline__ bool CompareFunc(T x, T y) { | |||||
| return abs(x) > y; | |||||
| } | |||||
| template <> | |||||
| __device__ __forceinline__ bool CompareFunc(half x, half y) { | |||||
| return abs(__half2float(x)) > __half2float(y); | |||||
| } | |||||
| template <typename T> | |||||
| __device__ __forceinline__ T Sgn(T x) { | |||||
| return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0); | |||||
| } | |||||
| template <> | |||||
| __device__ __forceinline__ half Sgn(half x) { | |||||
| return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0); | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate, | |||||
| const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power, | |||||
| T *variable, T *accumulation, T *linear) { | |||||
| const T two = static_cast<T>(2.0); | |||||
| const T learning_rate_power_val = -learning_rate_power[0]; | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i]; | |||||
| const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val); | |||||
| const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val); | |||||
| const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0]; | |||||
| linear[i] += gradient[i] - sigma * variable[i]; | |||||
| variable[i] = CompareFunc(linear[i], l1_regularization[0]) | |||||
| ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) / | |||||
| (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0])) | |||||
| : static_cast<T>(0); | |||||
| accumulation[i] = cur_accumulation; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, | |||||
| const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, | |||||
| cudaStream_t cuda_stream) { | |||||
| ApplyFtrlKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, l1_regularization, | |||||
| l2_regularization, learning_rate_power, variable, | |||||
| accumulation, linear); | |||||
| } | |||||
| template void ApplyFtrl<float>(const size_t size, const float *gradient, const float *learning_rate, | |||||
| const float *l1_regularization, const float *l2_regularization, | |||||
| const float *learning_rate_power, float *variable, float *accumulation, float *linear, | |||||
| cudaStream_t cuda_stream); | |||||
| template void ApplyFtrl<half>(const size_t size, const half *gradient, const half *learning_rate, | |||||
| const half *l1_regularization, const half *l2_regularization, | |||||
| const half *learning_rate_power, half *variable, half *accumulation, half *linear, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, | |||||
| const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/nn/ftrl_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(ApplyFtrl, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FtrlGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(ApplyFtrl, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| FtrlGpuKernel, half) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/cuda_impl/ftrl_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class FtrlGpuKernel : public GpuKernel { | |||||
| public: | |||||
| FtrlGpuKernel() | |||||
| : variable_size_(0), | |||||
| accumulation_size_(0), | |||||
| linear_size_(0), | |||||
| gradient_size_(0), | |||||
| learning_rate_size_(0), | |||||
| l1_regularization_size_(0), | |||||
| l2_regularization_size_(0), | |||||
| learning_rate_power_size_(0) {} | |||||
| ~FtrlGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||||
| void *stream_ptr) override { | |||||
| T *variable = GetDeviceAddress<T>(inputs, 0); | |||||
| T *accumulation = GetDeviceAddress<T>(inputs, 1); | |||||
| T *linear = GetDeviceAddress<T>(inputs, 2); | |||||
| T *gradient = GetDeviceAddress<T>(inputs, 3); | |||||
| T *learning_rate = GetDeviceAddress<T>(inputs, 4); | |||||
| T *l1_regularization = GetDeviceAddress<T>(inputs, 5); | |||||
| T *l2_regularization = GetDeviceAddress<T>(inputs, 6); | |||||
| T *learning_rate_power = GetDeviceAddress<T>(inputs, 7); | |||||
| ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, | |||||
| learning_rate_power, variable, accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 8) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; | |||||
| return false; | |||||
| } | |||||
| variable_size_ = sizeof(T); | |||||
| accumulation_size_ = sizeof(T); | |||||
| linear_size_ = sizeof(T); | |||||
| gradient_size_ = sizeof(T); | |||||
| learning_rate_size_ = sizeof(T); | |||||
| l1_regularization_size_ = sizeof(T); | |||||
| l2_regularization_size_ = sizeof(T); | |||||
| learning_rate_power_size_ = sizeof(T); | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < variable_shape.size(); i++) { | |||||
| variable_size_ *= variable_shape[i]; | |||||
| } | |||||
| auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < accumulation_shape.size(); i++) { | |||||
| accumulation_size_ *= accumulation_shape[i]; | |||||
| } | |||||
| auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| for (size_t i = 0; i < linear_shape.size(); i++) { | |||||
| linear_size_ *= linear_shape[i]; | |||||
| } | |||||
| auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); | |||||
| for (size_t i = 0; i < gradient_shape.size(); i++) { | |||||
| gradient_size_ *= gradient_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(variable_size_); | |||||
| input_size_list_.push_back(accumulation_size_); | |||||
| input_size_list_.push_back(linear_size_); | |||||
| input_size_list_.push_back(gradient_size_); | |||||
| input_size_list_.push_back(learning_rate_size_); | |||||
| input_size_list_.push_back(l1_regularization_size_); | |||||
| input_size_list_.push_back(l2_regularization_size_); | |||||
| input_size_list_.push_back(learning_rate_power_size_); | |||||
| output_size_list_.push_back(0); | |||||
| } | |||||
| private: | |||||
| size_t variable_size_; | |||||
| size_t accumulation_size_; | |||||
| size_t linear_size_; | |||||
| size_t gradient_size_; | |||||
| size_t learning_rate_size_; | |||||
| size_t l1_regularization_size_; | |||||
| size_t l2_regularization_size_; | |||||
| size_t learning_rate_power_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,78 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import Dense | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import FTRL | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class NetFtrl(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetFtrl, self).__init__() | |||||
| self.batch_size = 1 | |||||
| self.reshape = P.Reshape() | |||||
| weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) | |||||
| self.fc1 = Dense(16, 10, weight_init=weight) | |||||
| def construct(self, input_x): | |||||
| output = self.reshape(input_x, (self.batch_size, -1)) | |||||
| output = self.fc1(output) | |||||
| return output | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ftrl(): | |||||
| epoch = 3 | |||||
| net = NetFtrl() | |||||
| optimizer = FTRL(filter(lambda x: x.requires_grad, | |||||
| net.get_parameters()), learning_rate=0.01) | |||||
| criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||||
| net_with_criterion = WithLossCell(net, criterion) | |||||
| train_network = TrainOneStepCell( | |||||
| net_with_criterion, optimizer) | |||||
| train_network.set_train() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| losses1 = [] | |||||
| for _ in range(epoch): | |||||
| data = Tensor(np.arange(0, 16).reshape( | |||||
| 1, 1, 4, 4).astype(np.float32) * 0.01) | |||||
| label = Tensor(np.array([0]).astype(np.int32)) | |||||
| loss = train_network(data, label) | |||||
| losses1.append(loss.asnumpy()) | |||||
| assert losses1[0] > losses1[1] | |||||
| assert losses1[1] > losses1[2] | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| losses2 = [] | |||||
| for _ in range(epoch): | |||||
| data = Tensor(np.arange(0, 16).reshape( | |||||
| 1, 1, 4, 4).astype(np.float32) * 0.01) | |||||
| label = Tensor(np.array([0]).astype(np.int32)) | |||||
| loss = train_network(data, label) | |||||
| losses2.append(loss.asnumpy()) | |||||
| assert losses2[0] > losses2[1] | |||||
| assert losses2[1] > losses2[2] | |||||