diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc
index e4626940d0..e80037fa6e 100644
--- a/mindspore/ccsrc/kernel/common_utils.cc
+++ b/mindspore/ccsrc/kernel/common_utils.cc
@@ -632,7 +632,7 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
     }
     last_index = index;
   }
-  unique_grad->indices_size_ = unique_indices_size;
+  unique_grad->indices_size_ = unique_indices_size + 1;
 }
 }  // namespace kernel
 }  // namespace mindspore
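For context on the one-line fix above: `ReduceSparseGradient` merges gradient rows that share the same index, and `unique_indices_size` appears to track the last slot written rather than the slot count, hence the `+ 1`. A NumPy sketch of the intended semantics (my reading of the kernel, not code from this patch; the real implementation may order the unique indices differently):

```python
import numpy as np

def reduce_sparse_gradient(grad, indices, first_dim, outer_dim):
    """Sum gradient rows that share an index; returns (reduced_rows, unique_indices)."""
    grad = grad.reshape(len(indices), outer_dim)
    slot_of = {}          # original index -> slot in the reduced output
    unique_indices, rows = [], []
    for i, idx in enumerate(indices):
        idx = int(idx)
        if not 0 <= idx < first_dim:
            raise ValueError("index %d out of range" % idx)
        if idx not in slot_of:
            slot_of[idx] = len(unique_indices)
            unique_indices.append(idx)
            rows.append(grad[i].copy())
        else:
            rows[slot_of[idx]] += grad[i]
    # The reduced indices_size_ is the number of slots, i.e. last written slot + 1.
    return np.stack(rows), np.array(unique_indices, dtype=np.int32)

g = np.arange(6, dtype=np.float32).reshape(3, 2)
rows, uniq = reduce_sparse_gradient(g, [0, 2, 0], first_dim=3, outer_dim=2)
# uniq -> [0, 2]; rows[0] == g[0] + g[2] because both rows target index 0
```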
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc
index 4d03645578..13450d0485 100644
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc
@@ -22,6 +22,13 @@ namespace {
 constexpr size_t kSparseApplyAdamInputSize = 11;
 }  // namespace
 
+void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+}
+
 void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@@ -50,7 +57,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   indices_size_ = indices_shape[0];
   if (grad_shape[0] != indices_size_) {
-    MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
@@ -58,7 +65,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 }
 
 void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t,
-                                                    float *v, float beta1, float beta2) {
+                                                    float *v, float beta1, float beta2) const {
   MS_EXCEPTION_IF_NULL(m);
   MS_EXCEPTION_IF_NULL(m_t);
   MS_EXCEPTION_IF_NULL(v);
@@ -81,7 +88,7 @@ void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique
 }
 
 bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                      const std::vector<kernel::AddressPtr> &workspace,
                                       const std::vector<kernel::AddressPtr> & /*outputs*/) {
   if (inputs.size() < kSparseApplyAdamInputSize) {
     MS_LOG(EXCEPTION) << "Error input size!";
@@ -101,14 +108,12 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[9]->addr);
   auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
+  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
 
-  std::vector<float> new_grad;
-  new_grad.reserve(indices_size_ * var_outer_dim_size_);
-  std::vector<int> new_indices;
-  new_indices.reserve(indices_size_);
-  SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_});
-  DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
-                           var_outer_dim_size_);
+  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
+  ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
+                       var_outer_dim_size_);
   size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
   // Update momentum
   lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h
index ea1ce54995..71be65eca2 100644
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h
@@ -30,13 +30,13 @@ class SparseApplyAdamCPUKernel : public CPUKernel {
   ~SparseApplyAdamCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
-
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
               const std::vector<kernel::AddressPtr> &outputs) override;
 
  private:
   void UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, float *v, float beta1,
-                            float beta2);
+                            float beta2) const;
   size_t indices_size_{0};
   size_t var_first_dim_size_{0};
   size_t var_outer_dim_size_{1};
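For reference, the step these Adam kernels implement is the standard Adam rule with the sparse gradient scattered into its rows; a hedged NumPy sketch of the non-nesterov path (variable and helper names are mine, not the kernel's):

```python
import numpy as np

def sparse_apply_adam_reference(var, m, v, grad, indices,
                                beta1_power, beta2_power, lr, beta1, beta2, epsilon):
    """Non-lazy sparse Adam: m and v decay everywhere, the gradient is scattered
    into the touched rows, and var is updated densely."""
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    m *= beta1
    v *= beta2
    for row, idx in zip(grad, indices):       # scatter-add the sparse gradient
        m[idx] += (1 - beta1) * row
        v[idx] += (1 - beta2) * row * row
    var -= lr_t * m / (np.sqrt(v) + epsilon)  # dense weight update
    return var, m, v
```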
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc
index 9b4d536b36..49ef0813fa 100644
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc
@@ -58,7 +58,7 @@ void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   indices_size_ = indices_shape[0];
   if (grad_shape[0] != indices_size_) {
-    MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
   lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
   if (lr_ <= 0) {
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
index c0e091f02b..0d6e0405d9 100644
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
@@ -23,6 +23,13 @@ namespace {
 constexpr size_t kSparseApplyLazyAdamInputSize = 11;
 }  // namespace
 
+void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+}
+
 void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@@ -51,7 +58,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   indices_size_ = indices_shape[0];
   if (grad_shape[0] != indices_size_) {
-    MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
@@ -59,7 +66,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 }
 
 bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                          const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                          const std::vector<kernel::AddressPtr> &workspace,
                                           const std::vector<kernel::AddressPtr> & /*outputs*/) {
   if (inputs.size() < kSparseApplyLazyAdamInputSize) {
     MS_LOG(EXCEPTION) << "Error input size!";
@@ -79,14 +86,12 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr>
   auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[9]->addr);
   auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
+  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
 
-  std::vector<float> new_grad;
-  new_grad.reserve(indices_size_ * var_outer_dim_size_);
-  std::vector<int> new_indices;
-  new_indices.reserve(indices_size_);
-  SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_});
-  DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
-                           var_outer_dim_size_);
+  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
+  ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
+                       var_outer_dim_size_);
 
   lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
   for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) {
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h
index 0a52181561..795568a64d 100644
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h
@@ -29,7 +29,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
   ~SparseApplyLazyAdamCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
-
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
               const std::vector<kernel::AddressPtr> &outputs) override;
 
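The lazy variant differs in that `m`, `v` and `var` are only touched at the gradient's rows and left alone elsewhere; a minimal sketch under the same caveats, assuming `indices` has already been deduplicated:

```python
import numpy as np

def sparse_apply_lazy_adam_reference(var, m, v, grad, indices,
                                     beta1_power, beta2_power, lr, beta1, beta2, epsilon):
    """Lazy sparse Adam: only the rows named in indices are updated."""
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    for row, idx in zip(grad, indices):
        m[idx] = beta1 * m[idx] + (1 - beta1) * row
        v[idx] = beta2 * v[idx] + (1 - beta2) * row * row
        var[idx] -= lr_t * m[idx] / (np.sqrt(v[idx]) + epsilon)
    return var, m, v
```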
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
new file mode 100644
index 0000000000..69b4755bc1
--- /dev/null
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
@@ -0,0 +1,116 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h"
+#include "kernel/common_utils.h"
+#include "device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kSparseApplyProximalAdagradInputSize = 7;
+}  // namespace
+
+void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+}
+
+void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  std::vector<size_t> accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  std::vector<size_t> lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+  std::vector<size_t> l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+  std::vector<size_t> l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
+  std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5);
+  std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6);
+  if (!IsSameShape(var_shape, accum_shape)) {
+    MS_LOG(EXCEPTION) << "var and accum should have the same shape";
+  }
+  if (var_shape.empty()) {
+    MS_LOG(EXCEPTION) << "var must be at least 1D";
+  }
+  var_first_dim_size_ = var_shape[0];
+  for (size_t i = 1; i < var_shape.size(); ++i) {
+    if (var_shape[i] != grad_shape[i]) {
+      MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i;
+    }
+    var_outer_dim_size_ *= var_shape[i];
+  }
+  if (indices_shape.size() != 1) {
+    MS_LOG(EXCEPTION) << "indices must be a 1D vector";
+  }
+  indices_size_ = indices_shape[0];
+  if (grad_shape[0] != indices_size_) {
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
+  }
+  if (!lr_shape.empty()) {
+    MS_LOG(EXCEPTION) << "lr is not a scalar";
+  }
+  if (!l1_shape.empty()) {
+    MS_LOG(EXCEPTION) << "l1 is not a scalar";
+  }
+  if (!l2_shape.empty()) {
+    MS_LOG(EXCEPTION) << "l2 is not a scalar";
+  }
+}
+
+bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                                 const std::vector<kernel::AddressPtr> &workspace,
+                                                 const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyProximalAdagradInputSize) {
+    MS_LOG(EXCEPTION) << "Wrong input size!";
+  }
+
+  auto var = reinterpret_cast<float *>(inputs[0]->addr);
+  auto accum = reinterpret_cast<float *>(inputs[1]->addr);
+  auto lr = reinterpret_cast<float *>(inputs[2]->addr)[0];
+  auto l1 = reinterpret_cast<float *>(inputs[3]->addr)[0];
+  auto l2 = reinterpret_cast<float *>(inputs[4]->addr)[0];
+  auto grad = reinterpret_cast<float *>(inputs[5]->addr);
+  auto indices = reinterpret_cast<int *>(inputs[6]->addr);
+  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
+  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
+  ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
+                       var_outer_dim_size_);
+
+  for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) {
+    int index = unique_sparse_grad.indices_[i];
+    if (index < 0 || IntToSize(index) >= var_first_dim_size_) {
+      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
+    }
+    size_t start_index = var_outer_dim_size_ * index;
+    size_t end_index = start_index + var_outer_dim_size_;
+    for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) {
+      accum[j] += grad[k] * grad[k];
+      auto learning_rate = lr * (1 / std::sqrt(accum[j]));
+      auto prox_v = var[j];
+      prox_v -= grad[k] * learning_rate;
+      if (l1 > 0) {
+        var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast<float>(0.0)) /
+                 (1 + l2 * learning_rate);
+      } else {
+        var[j] = prox_v / (1 + l2 * learning_rate);
+      }
+    }
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
new file mode 100644
index 0000000000..082809a9c2
--- /dev/null
+++ b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "kernel/cpu/cpu_kernel.h"
+#include "kernel/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
+ public:
+  SparseApplyProximalAdagradCPUKernel() = default;
+  ~SparseApplyProximalAdagradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+};
+
+MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyProximalAdagradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
name="v") - def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices): + def construct(self, grad, indices): out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices) return out @@ -42,5 +49,5 @@ def test_net(): context.set_context(mode=context.GRAPH_MODE, device_target="CPU") sparse_apply_adam = Net() - output = sparse_apply_adam(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices) + output = sparse_apply_adam(gradient, indices) print(output[0].asnumpy()) diff --git a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py new file mode 100644 index 0000000000..0eaa11a201 --- /dev/null +++ b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") + self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") + self.lr = 0.01 + self.l1 = 0.0 + self.l2 = 0.0 + + def construct(self, grad, indices): + out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices) + return out + + +def test_net(): + gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) + indices = Tensor([0, 1, 2], mstype.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + sparse_apply_proximal_adagrad = Net() + output = sparse_apply_proximal_adagrad(gradient, indices) + print(output.asnumpy()[0])