| @@ -632,7 +632,7 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie | |||
| } | |||
| last_index = index; | |||
| } | |||
| unique_grad->indices_size_ = unique_indices_size; | |||
| unique_grad->indices_size_ = unique_indices_size + 1; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,13 @@ namespace { | |||
| constexpr size_t kSparseApplyAdamInputSize = 11; | |||
| } // namespace | |||
| void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||
| } | |||
| void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -50,7 +57,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| indices_size_ = indices_shape[0]; | |||
| if (grad_shape[0] != indices_size_) { | |||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||
| } | |||
| if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | |||
| use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | |||
| @@ -58,7 +65,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, | |||
| float *v, float beta1, float beta2) { | |||
| float *v, float beta1, float beta2) const { | |||
| MS_EXCEPTION_IF_NULL(m); | |||
| MS_EXCEPTION_IF_NULL(m_t); | |||
| MS_EXCEPTION_IF_NULL(v); | |||
| @@ -81,7 +88,7 @@ void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique | |||
| } | |||
| bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | |||
| if (inputs.size() < kSparseApplyAdamInputSize) { | |||
| MS_LOG(EXCEPTION) << "Error input size!"; | |||
| @@ -101,14 +108,12 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp | |||
| auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | |||
| auto grad = reinterpret_cast<float *>(inputs[9]->addr); | |||
| auto indices = reinterpret_cast<int *>(inputs[10]->addr); | |||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||
| std::vector<float> new_grad; | |||
| new_grad.reserve(indices_size_ * var_outer_dim_size_); | |||
| std::vector<int> new_indices; | |||
| new_indices.reserve(indices_size_); | |||
| SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_}); | |||
| DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||
| var_outer_dim_size_); | |||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||
| var_outer_dim_size_); | |||
| size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; | |||
| // Update momentum | |||
| lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | |||
| @@ -30,13 +30,13 @@ class SparseApplyAdamCPUKernel : public CPUKernel { | |||
| ~SparseApplyAdamCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| void UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, float *v, float beta1, | |||
| float beta2); | |||
| float beta2) const; | |||
| size_t indices_size_{0}; | |||
| size_t var_first_dim_size_{0}; | |||
| size_t var_outer_dim_size_{1}; | |||
| @@ -58,7 +58,7 @@ void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| indices_size_ = indices_shape[0]; | |||
| if (grad_shape[0] != indices_size_) { | |||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||
| } | |||
| lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr"); | |||
| if (lr_ <= 0) { | |||
| @@ -23,6 +23,13 @@ namespace { | |||
| constexpr size_t kSparseApplyLazyAdamInputSize = 11; | |||
| } // namespace | |||
| void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||
| } | |||
| void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -51,7 +58,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| indices_size_ = indices_shape[0]; | |||
| if (grad_shape[0] != indices_size_) { | |||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||
| } | |||
| if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | |||
| use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | |||
| @@ -59,7 +66,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | |||
| if (inputs.size() < kSparseApplyLazyAdamInputSize) { | |||
| MS_LOG(EXCEPTION) << "Error input size!"; | |||
| @@ -79,14 +86,12 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> | |||
| auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | |||
| auto grad = reinterpret_cast<float *>(inputs[9]->addr); | |||
| auto indices = reinterpret_cast<int *>(inputs[10]->addr); | |||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||
| std::vector<float> new_grad; | |||
| new_grad.reserve(indices_size_ * var_outer_dim_size_); | |||
| std::vector<int> new_indices; | |||
| new_indices.reserve(indices_size_); | |||
| SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_}); | |||
| DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||
| var_outer_dim_size_); | |||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||
| var_outer_dim_size_); | |||
| lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | |||
| for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) { | |||
| @@ -29,7 +29,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel { | |||
| ~SparseApplyLazyAdamCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| @@ -0,0 +1,116 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kSparseApplyProximalAdagradInputSize = 7; | |||
| } // namespace | |||
| void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||
| } | |||
| void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| std::vector<size_t> accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| std::vector<size_t> lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| std::vector<size_t> l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); | |||
| std::vector<size_t> l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); | |||
| std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5); | |||
| std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6); | |||
| if (!IsSameShape(var_shape, accum_shape)) { | |||
| MS_LOG(EXCEPTION) << "var and accum should have the same shape"; | |||
| } | |||
| if (var_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "var must be at least 1D"; | |||
| } | |||
| var_first_dim_size_ = var_shape[0]; | |||
| for (size_t i = 1; i < var_shape.size(); ++i) { | |||
| if (var_shape[i] != grad_shape[i]) { | |||
| MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; | |||
| } | |||
| var_outer_dim_size_ *= var_shape[i]; | |||
| } | |||
| if (indices_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "indices must be a 1D vector"; | |||
| } | |||
| indices_size_ = indices_shape[0]; | |||
| if (grad_shape[0] != indices_size_) { | |||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||
| } | |||
| if (!lr_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "lr is not a scalar"; | |||
| } | |||
| if (!l1_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "l1 is not a scalar"; | |||
| } | |||
| if (!l2_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "l2 is not a scalar"; | |||
| } | |||
| } | |||
| bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | |||
| if (inputs.size() < kSparseApplyProximalAdagradInputSize) { | |||
| MS_LOG(EXCEPTION) << "Wrong input size!"; | |||
| } | |||
| auto var = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto accum = reinterpret_cast<float *>(inputs[1]->addr); | |||
| auto lr = reinterpret_cast<float *>(inputs[2]->addr)[0]; | |||
| auto l1 = reinterpret_cast<float *>(inputs[3]->addr)[0]; | |||
| auto l2 = reinterpret_cast<float *>(inputs[4]->addr)[0]; | |||
| auto grad = reinterpret_cast<float *>(inputs[5]->addr); | |||
| auto indices = reinterpret_cast<int *>(inputs[6]->addr); | |||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||
| var_outer_dim_size_); | |||
| for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) { | |||
| int index = unique_sparse_grad.indices_[i]; | |||
| if (index < 0 || IntToSize(index) >= var_first_dim_size_) { | |||
| MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; | |||
| } | |||
| size_t start_index = var_outer_dim_size_ * index; | |||
| size_t end_index = start_index + var_outer_dim_size_; | |||
| for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { | |||
| accum[j] += grad[k] * grad[k]; | |||
| auto learning_rate = lr * (1 / std::sqrt(accum[j])); | |||
| auto prox_v = var[j]; | |||
| prox_v -= grad[k] * learning_rate; | |||
| if (l1 > 0) { | |||
| var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast<float>(0.0)) / | |||
| (1 + l2 * learning_rate); | |||
| } else { | |||
| var[j] = prox_v / (1 + l2 * learning_rate); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/cpu/cpu_kernel.h" | |||
| #include "kernel/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SparseApplyProximalAdagradCPUKernel : public CPUKernel { | |||
| public: | |||
| SparseApplyProximalAdagradCPUKernel() = default; | |||
| ~SparseApplyProximalAdagradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| size_t indices_size_{0}; | |||
| size_t var_first_dim_size_{0}; | |||
| size_t var_outer_dim_size_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SparseApplyProximalAdagradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||
| @@ -21,6 +21,13 @@ from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| import mindspore.common.dtype as mstype | |||
| beta1_power = 0.9 | |||
| beta2_power = 0.999 | |||
| lr = 0.001 | |||
| beta1 = 0.9 | |||
| beta2 = 0.999 | |||
| epsilon = 1e-8 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| @@ -30,7 +37,7 @@ class Net(nn.Cell): | |||
| self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m") | |||
| self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v") | |||
| def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices): | |||
| def construct(self, grad, indices): | |||
| out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, | |||
| grad, indices) | |||
| return out | |||
| @@ -42,5 +49,5 @@ def test_net(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| sparse_apply_adam = Net() | |||
| output = sparse_apply_adam(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices) | |||
| output = sparse_apply_adam(gradient, indices) | |||
| print(output[0].asnumpy()) | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| import mindspore.common.dtype as mstype | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() | |||
| self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") | |||
| self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") | |||
| self.lr = 0.01 | |||
| self.l1 = 0.0 | |||
| self.l2 = 0.0 | |||
| def construct(self, grad, indices): | |||
| out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices) | |||
| return out | |||
| def test_net(): | |||
| gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) | |||
| indices = Tensor([0, 1, 2], mstype.int32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| sparse_apply_proximal_adagrad = Net() | |||
| output = sparse_apply_proximal_adagrad(gradient, indices) | |||
| print(output.asnumpy()[0]) | |||