Merge pull request !322 from chenweifeng/rmsproptags/v0.3.0-alpha
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, | |||
| T* mean_square, T*moment, T* gradients, const size_t size) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; | |||
| moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i]; | |||
| variable[i] -= moment[i]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, | |||
| T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { | |||
| RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon, | |||
| variable, mean_square, moment, gradients, size); | |||
| } | |||
| template <typename T> | |||
| __global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, | |||
| T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, | |||
| const size_t size) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; | |||
| mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; | |||
| moment[i] = momentum[0] * moment[i] + learning_rate[0] * | |||
| rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; | |||
| variable[i] -= moment[i]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, | |||
| T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, | |||
| cudaStream_t cuda_stream) { | |||
| RmsPropCenterKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon, | |||
| variable, mean_gradients, mean_square, | |||
| moment, gradients, size); | |||
| } | |||
| template | |||
| void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, | |||
| float* variable, float* mean_square, float* moment, float* gradients, const size_t size, | |||
| cudaStream_t cuda_stream); | |||
| template | |||
| void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, | |||
| float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, | |||
| const size_t size, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square, | |||
| T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, | |||
| T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/rmsprop_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(ApplyRMSProp, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| RMSPropGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| RMSPropGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class RMSPropGpuKernel : public GpuKernel { | |||
| public: | |||
| RMSPropGpuKernel() : size_(1), use_center_(false) {} | |||
| ~RMSPropGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream) override { | |||
| if (!use_center_) { | |||
| T *variable = GetDeviceAddress<T>(inputs, 0); | |||
| T *mean_square = GetDeviceAddress<T>(inputs, 1); | |||
| T *moment = GetDeviceAddress<T>(inputs, 2); | |||
| T *gradients = GetDeviceAddress<T>(inputs, 3); | |||
| T *learning_rate = GetDeviceAddress<T>(inputs, 4); | |||
| T *decay = GetDeviceAddress<T>(inputs, 5); | |||
| T *momentum = GetDeviceAddress<T>(inputs, 6); | |||
| T *epsilon = GetDeviceAddress<T>(inputs, 7); | |||
| RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_, | |||
| reinterpret_cast<cudaStream_t>(stream)); | |||
| } else { | |||
| T *variable = GetDeviceAddress<T>(inputs, 0); | |||
| T *mean_gradients = GetDeviceAddress<T>(inputs, 1); | |||
| T *mean_square = GetDeviceAddress<T>(inputs, 2); | |||
| T *moment = GetDeviceAddress<T>(inputs, 3); | |||
| T *gradients = GetDeviceAddress<T>(inputs, 4); | |||
| T *learning_rate = GetDeviceAddress<T>(inputs, 5); | |||
| T *decay = GetDeviceAddress<T>(inputs, 6); | |||
| T *momentum = GetDeviceAddress<T>(inputs, 7); | |||
| T *epsilon = GetDeviceAddress<T>(inputs, 8); | |||
| RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients, | |||
| size_, reinterpret_cast<cudaStream_t>(stream)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (node_name == "ApplyCenteredRMSProp") { | |||
| use_center_ = true; | |||
| } | |||
| auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| for (auto &dim : input_shape) { | |||
| size_ *= dim; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| size_t input_size = size_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); | |||
| if (use_center_) { | |||
| input_size_list_.push_back(input_size); | |||
| } | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| output_size_list_.push_back(0); | |||
| } | |||
| private: | |||
| size_t size_; | |||
| bool use_center_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -0,0 +1,152 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| class NetRMSProp(nn.Cell): | |||
| def __init__(self, use_centered): | |||
| super(NetRMSProp, self).__init__() | |||
| self.use_centered = use_centered | |||
| if use_centered: | |||
| self.rms_opt = P.ApplyCenteredRMSProp() | |||
| else: | |||
| self.rms_opt = P.ApplyRMSProp() | |||
| def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon): | |||
| if self.use_centered: | |||
| return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon) | |||
| else: | |||
| return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon) | |||
| def rmsprop_numpy(variable, gradients, mean_square, moment, | |||
| learning_rate, decay, momentum, epsilon): | |||
| mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients | |||
| moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients | |||
| variable = variable - moment | |||
| def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment, | |||
| learning_rate, decay, momentum, epsilon): | |||
| mean_gradients = mean_gradients * decay + (1.0 - decay) * gradients | |||
| mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients | |||
| moment = momentum * moment + learning_rate / np.sqrt(mean_square -mean_gradients * mean_gradients + epsilon) * gradients | |||
| variable = variable - moment | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_rmsprop(): | |||
| learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, True] | |||
| variable_np = np.array([1.0, 2.0], dtype=np.float32) | |||
| gradients_np = np.array([0.1, 0.2], dtype=np.float32) | |||
| mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) | |||
| mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) | |||
| moment_np = np.array([0.0, 0.0], dtype=np.float32) | |||
| variable_ms = Tensor(variable_np) | |||
| gradients_ms = Tensor(gradients_np) | |||
| mean_gradients_ms = Tensor(mean_gradients_np) | |||
| mean_square_ms = Tensor(mean_square_np) | |||
| moment_ms = Tensor(moment_np) | |||
| if centered: | |||
| rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| else: | |||
| rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetRMSProp(centered) | |||
| _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, | |||
| moment_ms, learning_rate, decay, momentum, epsilon) | |||
| error = np.ones(shape=variable_np.shape) * 10e-6 | |||
| diff = variable_ms.asnumpy() - variable_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=gradients_np.shape) * 10e-6 | |||
| diff = gradients_ms.asnumpy() - gradients_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=mean_gradients_np.shape) * 10e-6 | |||
| diff = mean_gradients_ms.asnumpy() - mean_gradients_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=mean_square_np.shape) * 10e-6 | |||
| diff = mean_square_ms.asnumpy() - mean_square_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=moment_np.shape) * 10e-6 | |||
| diff = moment_ms.asnumpy() - moment_np | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_rmspropcenter(): | |||
| learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, False] | |||
| variable_np = np.array([1.0, 2.0], dtype=np.float32) | |||
| gradients_np = np.array([0.1, 0.2], dtype=np.float32) | |||
| mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) | |||
| mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) | |||
| moment_np = np.array([0.0, 0.0], dtype=np.float32) | |||
| variable_ms = Tensor(variable_np) | |||
| gradients_ms = Tensor(gradients_np) | |||
| mean_gradients_ms = Tensor(mean_gradients_np) | |||
| mean_square_ms = Tensor(mean_square_np) | |||
| moment_ms = Tensor(moment_np) | |||
| if centered: | |||
| rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| else: | |||
| rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetRMSProp(centered) | |||
| _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms, | |||
| learning_rate, decay, momentum, epsilon) | |||
| error = np.ones(shape=variable_np.shape) * 10e-6 | |||
| diff = variable_ms.asnumpy() - variable_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=gradients_np.shape) * 10e-6 | |||
| diff = gradients_ms.asnumpy() - gradients_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=mean_gradients_np.shape) * 10e-6 | |||
| diff = mean_gradients_ms.asnumpy() - mean_gradients_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=mean_square_np.shape) * 10e-6 | |||
| diff = mean_square_ms.asnumpy() - mean_square_np | |||
| assert np.all(diff < error) | |||
| error = np.ones(shape=moment_np.shape) * 10e-6 | |||
| diff = moment_ms.asnumpy() - moment_np | |||
| assert np.all(diff < error) | |||