diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc new file mode 100644 index 0000000000..452f9123e7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/rmsprop_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +void RMSPropCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "ApplyCenteredRMSProp") { + use_center_ = true; + } + + if (node_name == "ApplyRMSProp") { + decay_ = AnfAlgo::GetNodeAttr(kernel_node, "rho"); + momentum_ = AnfAlgo::GetNodeAttr(kernel_node, "momentum"); + epsilon_ = AnfAlgo::GetNodeAttr(kernel_node, "epsilon"); + } + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } +} + +bool RMSPropCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (!use_center_) { + float *variable = reinterpret_cast(inputs[0]->addr); + float *mean_square = reinterpret_cast(inputs[1]->addr); + float *moment = reinterpret_cast(inputs[2]->addr); + float *learning_rate = reinterpret_cast(inputs[3]->addr); + float *gradients = reinterpret_cast(inputs[4]->addr); + + for (size_t i = 0; i < size_; i++) { + mean_square[i] += (gradients[i] * gradients[i] - mean_square[i]) * (1.0 - decay_); + moment[i] = moment[i] * momentum_ + (gradients[i] * learning_rate[0]) / sqrt(mean_square[i] + epsilon_); + variable[i] -= moment[i]; + } + } else { + float *variable = reinterpret_cast(inputs[0]->addr); + float *mean_gradients = reinterpret_cast(inputs[1]->addr); + float *mean_square = reinterpret_cast(inputs[2]->addr); + float *moment = reinterpret_cast(inputs[3]->addr); + float *gradients = reinterpret_cast(inputs[4]->addr); + float *learning_rate = reinterpret_cast(inputs[5]->addr); + float *decay = reinterpret_cast(inputs[6]->addr); + float *momentum = reinterpret_cast(inputs[7]->addr); + float *epsilon = reinterpret_cast(inputs[8]->addr); + + for (size_t i = 0; i < size_; i++) { + mean_square[i] += (gradients[i] * gradients[i] - mean_square[i]) * (1.0 - decay[0]); + mean_gradients[i] += (gradients[i] - mean_gradients[i]) * (1.0 - decay[0]); + auto denom = (mean_square[i] - mean_gradients[i] * mean_gradients[i]) + epsilon[0]; + if (denom > 0) { + moment[i] = moment[i] * momentum[0] + (gradients[i] * learning_rate[0]) / sqrt(denom); + variable[i] -= moment[i]; + } + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.h new file mode 100644 index 0000000000..661da1621a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class RMSPropCPUKernel : public CPUKernel { + public: + RMSPropCPUKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {} + ~RMSPropCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t size_; + bool use_center_; + float decay_; + float momentum_; + float epsilon_; +}; + +MS_REG_CPU_KERNEL(ApplyRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropCPUKernel); + +MS_REG_CPU_KERNEL(ApplyCenteredRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_rmsprop.py b/tests/st/ops/cpu/test_rmsprop.py new file mode 100644 index 0000000000..4121d46071 --- /dev/null +++ b/tests/st/ops/cpu/test_rmsprop.py @@ -0,0 +1,202 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetCenteredRMSProp(nn.Cell): + def __init__(self, lr, decay, momentum, epsilon, var, g, mg, rms, mom): + super(NetCenteredRMSProp, self).__init__() + self.rms_opt = P.ApplyCenteredRMSProp() + self.lr = lr + self.decay = decay + self.momentum = momentum + self.epsilon = epsilon + self.var = var + self.g = g + self.mg = mg + self.rms = rms + self.mom = mom + + def construct(self): + return self.rms_opt(self.var, self.mg, self.rms, self.mom, self.g, self.lr, self.decay, self.momentum, + self.epsilon) + + +class NetRMSProp(nn.Cell): + def __init__(self, lr, decay, momentum, epsilon, var, g, mg, rms, mom): + super(NetRMSProp, self).__init__() + self.lr = lr + self.decay = decay + self.momentum = momentum + self.epsilon = epsilon + self.var = var + self.g = g + self.mg = mg + self.rms = rms + self.mom = mom + self.rms_opt = P.ApplyRMSProp() + + def construct(self): + return self.rms_opt(self.var, self.rms, self.mom, self.lr, self.g, self.decay, self.momentum, self.epsilon) + + +def rmsprop_numpy(variable, gradients, mean_square, moment, + learning_rate, decay, momentum, epsilon): + mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients + moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients + variable = variable - moment + return variable, gradients, mean_square, moment + + +def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment, + learning_rate, decay, momentum, epsilon): + mean_gradients = mean_gradients * decay + (1.0 - decay) * gradients + mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients + moment = momentum * moment + learning_rate / np.sqrt( + mean_square - mean_gradients * mean_gradients + epsilon) * gradients + variable = variable - moment + return variable, gradients, mean_gradients, mean_square, moment + + +@pytest.mark.level0 +@pytest.mark.platform_cpu_training +@pytest.mark.env_onecard +def test_rmsprop(): + learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, True] + + variable_np = np.array([1.0, 2.0], dtype=np.float32) + gradients_np = np.array([0.1, 0.2], dtype=np.float32) + mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) + mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) + moment_np = np.array([0.0, 0.0], dtype=np.float32) + + variable = Tensor(variable_np) + gradients = Tensor(gradients_np) + mean_gradients = Tensor(mean_gradients_np) + mean_square = Tensor(mean_square_np) + moment = Tensor(moment_np) + + variable_ms = Parameter(initializer(variable, variable.shape), name='var') + gradients_ms = Parameter(initializer(gradients, gradients.shape), name='grad') + mean_gradients_ms = Parameter(initializer(mean_gradients, mean_gradients.shape), name='mg') + mean_square_ms = Parameter(initializer(mean_square, mean_square.shape), name='msr') + moment_ms = Parameter(initializer(moment, moment.shape), name='mom') + + if centered: + variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \ + rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms, + mean_square_ms, moment_ms) + _ = net() + + else: + variable_np, gradients_np, mean_square_np, moment_np = \ + rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + net = NetRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms, + mean_square_ms, moment_ms) + _ = net() + + error = np.ones(shape=variable_np.shape) * 10e-6 + diff = variable_ms.asnumpy() - variable_np + assert np.all(diff < error) + + error = np.ones(shape=gradients_np.shape) * 10e-6 + diff = gradients_ms.asnumpy() - gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_gradients_np.shape) * 10e-6 + diff = mean_gradients_ms.asnumpy() - mean_gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_square_np.shape) * 10e-6 + diff = mean_square_ms.asnumpy() - mean_square_np + assert np.all(diff < error) + + error = np.ones(shape=moment_np.shape) * 10e-6 + diff = moment_ms.asnumpy() - moment_np + assert np.all(diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_cpu_training +@pytest.mark.env_onecard +def test_rmspropcenter(): + learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, False] + + variable_np = np.array([1.0, 2.0], dtype=np.float32) + gradients_np = np.array([0.1, 0.2], dtype=np.float32) + mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) + mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) + moment_np = np.array([0.0, 0.0], dtype=np.float32) + + variable = Tensor(variable_np) + gradients = Tensor(gradients_np) + mean_gradients = Tensor(mean_gradients_np) + mean_square = Tensor(mean_square_np) + moment = Tensor(moment_np) + + variable_ms = Parameter(initializer(variable, variable.shape), name='var') + gradients_ms = Parameter(initializer(gradients, gradients.shape), name='grad') + mean_gradients_ms = Parameter(initializer(mean_gradients, mean_gradients.shape), name='mg') + mean_square_ms = Parameter(initializer(mean_square, mean_square.shape), name='msr') + moment_ms = Parameter(initializer(moment, moment.shape), name='mom') + + if centered: + variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \ + rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms, + mean_square_ms, moment_ms) + _ = net() + else: + variable_np, gradients_np, mean_square_np, moment_np = \ + rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + net = NetRMSProp(learning_rate, decay, momentum, epsilon, variable_ms, gradients_ms, mean_gradients_ms, + mean_square_ms, moment_ms) + _ = net() + + error = np.ones(shape=variable_np.shape) * 10e-6 + diff = variable_ms.asnumpy() - variable_np + assert np.all(diff < error) + + error = np.ones(shape=gradients_np.shape) * 10e-6 + diff = gradients_ms.asnumpy() - gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_gradients_np.shape) * 10e-6 + diff = mean_gradients_ms.asnumpy() - mean_gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_square_np.shape) * 10e-6 + diff = mean_square_ms.asnumpy() - mean_square_np + assert np.all(diff < error) + + error = np.ones(shape=moment_np.shape) * 10e-6 + diff = moment_ms.asnumpy() - moment_np + assert np.all(diff < error)