| @@ -632,7 +632,7 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie | |||||
| } | } | ||||
| last_index = index; | last_index = index; | ||||
| } | } | ||||
| unique_grad->indices_size_ = unique_indices_size; | |||||
| unique_grad->indices_size_ = unique_indices_size + 1; | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,6 +22,13 @@ namespace { | |||||
| constexpr size_t kSparseApplyAdamInputSize = 11; | constexpr size_t kSparseApplyAdamInputSize = 11; | ||||
| } // namespace | } // namespace | ||||
| void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||||
| CPUKernel::InitInputOutputSize(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||||
| } | |||||
| void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| @@ -50,7 +57,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| indices_size_ = indices_shape[0]; | indices_size_ = indices_shape[0]; | ||||
| if (grad_shape[0] != indices_size_) { | if (grad_shape[0] != indices_size_) { | ||||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||||
| } | } | ||||
| if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | ||||
| use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | ||||
| @@ -58,7 +65,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, | void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, | ||||
| float *v, float beta1, float beta2) { | |||||
| float *v, float beta1, float beta2) const { | |||||
| MS_EXCEPTION_IF_NULL(m); | MS_EXCEPTION_IF_NULL(m); | ||||
| MS_EXCEPTION_IF_NULL(m_t); | MS_EXCEPTION_IF_NULL(m_t); | ||||
| MS_EXCEPTION_IF_NULL(v); | MS_EXCEPTION_IF_NULL(v); | ||||
| @@ -81,7 +88,7 @@ void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique | |||||
| } | } | ||||
| bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &workspace, | |||||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | const std::vector<kernel::AddressPtr> & /*outputs*/) { | ||||
| if (inputs.size() < kSparseApplyAdamInputSize) { | if (inputs.size() < kSparseApplyAdamInputSize) { | ||||
| MS_LOG(EXCEPTION) << "Error input size!"; | MS_LOG(EXCEPTION) << "Error input size!"; | ||||
| @@ -101,14 +108,12 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp | |||||
| auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | ||||
| auto grad = reinterpret_cast<float *>(inputs[9]->addr); | auto grad = reinterpret_cast<float *>(inputs[9]->addr); | ||||
| auto indices = reinterpret_cast<int *>(inputs[10]->addr); | auto indices = reinterpret_cast<int *>(inputs[10]->addr); | ||||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||||
| std::vector<float> new_grad; | |||||
| new_grad.reserve(indices_size_ * var_outer_dim_size_); | |||||
| std::vector<int> new_indices; | |||||
| new_indices.reserve(indices_size_); | |||||
| SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_}); | |||||
| DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||||
| var_outer_dim_size_); | |||||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||||
| var_outer_dim_size_); | |||||
| size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; | size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; | ||||
| // Update momentum | // Update momentum | ||||
| lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | ||||
| @@ -30,13 +30,13 @@ class SparseApplyAdamCPUKernel : public CPUKernel { | |||||
| ~SparseApplyAdamCPUKernel() override = default; | ~SparseApplyAdamCPUKernel() override = default; | ||||
| void InitKernel(const CNodePtr &kernel_node) override; | void InitKernel(const CNodePtr &kernel_node) override; | ||||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| void UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, float *v, float beta1, | void UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, float *v, float beta1, | ||||
| float beta2); | |||||
| float beta2) const; | |||||
| size_t indices_size_{0}; | size_t indices_size_{0}; | ||||
| size_t var_first_dim_size_{0}; | size_t var_first_dim_size_{0}; | ||||
| size_t var_outer_dim_size_{1}; | size_t var_outer_dim_size_{1}; | ||||
| @@ -58,7 +58,7 @@ void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| indices_size_ = indices_shape[0]; | indices_size_ = indices_shape[0]; | ||||
| if (grad_shape[0] != indices_size_) { | if (grad_shape[0] != indices_size_) { | ||||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||||
| } | } | ||||
| lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr"); | lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr"); | ||||
| if (lr_ <= 0) { | if (lr_ <= 0) { | ||||
| @@ -23,6 +23,13 @@ namespace { | |||||
| constexpr size_t kSparseApplyLazyAdamInputSize = 11; | constexpr size_t kSparseApplyLazyAdamInputSize = 11; | ||||
| } // namespace | } // namespace | ||||
| void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||||
| CPUKernel::InitInputOutputSize(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||||
| } | |||||
| void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| @@ -51,7 +58,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| indices_size_ = indices_shape[0]; | indices_size_ = indices_shape[0]; | ||||
| if (grad_shape[0] != indices_size_) { | if (grad_shape[0] != indices_size_) { | ||||
| MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; | |||||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||||
| } | } | ||||
| if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { | ||||
| use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov"); | ||||
| @@ -59,7 +66,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &workspace, | |||||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | const std::vector<kernel::AddressPtr> & /*outputs*/) { | ||||
| if (inputs.size() < kSparseApplyLazyAdamInputSize) { | if (inputs.size() < kSparseApplyLazyAdamInputSize) { | ||||
| MS_LOG(EXCEPTION) << "Error input size!"; | MS_LOG(EXCEPTION) << "Error input size!"; | ||||
| @@ -79,14 +86,12 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> | |||||
| auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; | ||||
| auto grad = reinterpret_cast<float *>(inputs[9]->addr); | auto grad = reinterpret_cast<float *>(inputs[9]->addr); | ||||
| auto indices = reinterpret_cast<int *>(inputs[10]->addr); | auto indices = reinterpret_cast<int *>(inputs[10]->addr); | ||||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||||
| std::vector<float> new_grad; | |||||
| new_grad.reserve(indices_size_ * var_outer_dim_size_); | |||||
| std::vector<int> new_indices; | |||||
| new_indices.reserve(indices_size_); | |||||
| SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_}); | |||||
| DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||||
| var_outer_dim_size_); | |||||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||||
| var_outer_dim_size_); | |||||
| lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | ||||
| for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) { | for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) { | ||||
| @@ -29,7 +29,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel { | |||||
| ~SparseApplyLazyAdamCPUKernel() override = default; | ~SparseApplyLazyAdamCPUKernel() override = default; | ||||
| void InitKernel(const CNodePtr &kernel_node) override; | void InitKernel(const CNodePtr &kernel_node) override; | ||||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| @@ -0,0 +1,116 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" | |||||
| #include "kernel/common_utils.h" | |||||
| #include "device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| namespace { | |||||
| constexpr size_t kSparseApplyProximalAdagradInputSize = 7; | |||||
| } // namespace | |||||
| void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||||
| CPUKernel::InitInputOutputSize(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||||
| } | |||||
| void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| std::vector<size_t> accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| std::vector<size_t> lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| std::vector<size_t> l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); | |||||
| std::vector<size_t> l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); | |||||
| std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5); | |||||
| std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6); | |||||
| if (!IsSameShape(var_shape, accum_shape)) { | |||||
| MS_LOG(EXCEPTION) << "var and accum should have the same shape"; | |||||
| } | |||||
| if (var_shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "var must be at least 1D"; | |||||
| } | |||||
| var_first_dim_size_ = var_shape[0]; | |||||
| for (size_t i = 1; i < var_shape.size(); ++i) { | |||||
| if (var_shape[i] != grad_shape[i]) { | |||||
| MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; | |||||
| } | |||||
| var_outer_dim_size_ *= var_shape[i]; | |||||
| } | |||||
| if (indices_shape.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "indices must be a 1D vector"; | |||||
| } | |||||
| indices_size_ = indices_shape[0]; | |||||
| if (grad_shape[0] != indices_size_) { | |||||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||||
| } | |||||
| if (!lr_shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "lr is not a scalar"; | |||||
| } | |||||
| if (!l1_shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "l1 is not a scalar"; | |||||
| } | |||||
| if (!l2_shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "l2 is not a scalar"; | |||||
| } | |||||
| } | |||||
| bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &workspace, | |||||
| const std::vector<kernel::AddressPtr> & /*outputs*/) { | |||||
| if (inputs.size() < kSparseApplyProximalAdagradInputSize) { | |||||
| MS_LOG(EXCEPTION) << "Wrong input size!"; | |||||
| } | |||||
| auto var = reinterpret_cast<float *>(inputs[0]->addr); | |||||
| auto accum = reinterpret_cast<float *>(inputs[1]->addr); | |||||
| auto lr = reinterpret_cast<float *>(inputs[2]->addr)[0]; | |||||
| auto l1 = reinterpret_cast<float *>(inputs[3]->addr)[0]; | |||||
| auto l2 = reinterpret_cast<float *>(inputs[4]->addr)[0]; | |||||
| auto grad = reinterpret_cast<float *>(inputs[5]->addr); | |||||
| auto indices = reinterpret_cast<int *>(inputs[6]->addr); | |||||
| auto new_grad = reinterpret_cast<float *>(workspace[0]->addr); | |||||
| auto new_indices = reinterpret_cast<int *>(workspace[1]->addr); | |||||
| SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); | |||||
| ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, | |||||
| var_outer_dim_size_); | |||||
| for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) { | |||||
| int index = unique_sparse_grad.indices_[i]; | |||||
| if (index < 0 || IntToSize(index) >= var_first_dim_size_) { | |||||
| MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; | |||||
| } | |||||
| size_t start_index = var_outer_dim_size_ * index; | |||||
| size_t end_index = start_index + var_outer_dim_size_; | |||||
| for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { | |||||
| accum[j] += grad[k] * grad[k]; | |||||
| auto learning_rate = lr * (1 / std::sqrt(accum[j])); | |||||
| auto prox_v = var[j]; | |||||
| prox_v -= grad[k] * learning_rate; | |||||
| if (l1 > 0) { | |||||
| var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast<float>(0.0)) / | |||||
| (1 + l2 * learning_rate); | |||||
| } else { | |||||
| var[j] = prox_v / (1 + l2 * learning_rate); | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "kernel/cpu/cpu_kernel.h" | |||||
| #include "kernel/cpu/cpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class SparseApplyProximalAdagradCPUKernel : public CPUKernel { | |||||
| public: | |||||
| SparseApplyProximalAdagradCPUKernel() = default; | |||||
| ~SparseApplyProximalAdagradCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| private: | |||||
| size_t indices_size_{0}; | |||||
| size_t var_first_dim_size_{0}; | |||||
| size_t var_outer_dim_size_{1}; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| SparseApplyProximalAdagradCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ | |||||
| @@ -21,6 +21,13 @@ from mindspore.common.parameter import Parameter | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| beta1_power = 0.9 | |||||
| beta2_power = 0.999 | |||||
| lr = 0.001 | |||||
| beta1 = 0.9 | |||||
| beta2 = 0.999 | |||||
| epsilon = 1e-8 | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -30,7 +37,7 @@ class Net(nn.Cell): | |||||
| self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m") | self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m") | ||||
| self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v") | self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v") | ||||
| def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices): | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, | out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, | ||||
| grad, indices) | grad, indices) | ||||
| return out | return out | ||||
| @@ -42,5 +49,5 @@ def test_net(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | ||||
| sparse_apply_adam = Net() | sparse_apply_adam = Net() | ||||
| output = sparse_apply_adam(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices) | |||||
| output = sparse_apply_adam(gradient, indices) | |||||
| print(output[0].asnumpy()) | print(output[0].asnumpy()) | ||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.common.dtype as mstype | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() | |||||
| self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") | |||||
| self.lr = 0.01 | |||||
| self.l1 = 0.0 | |||||
| self.l2 = 0.0 | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices) | |||||
| return out | |||||
| def test_net(): | |||||
| gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) | |||||
| indices = Tensor([0, 1, 2], mstype.int32) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| sparse_apply_proximal_adagrad = Net() | |||||
| output = sparse_apply_proximal_adagrad(gradient, indices) | |||||
| print(output.asnumpy()[0]) | |||||