|
|
|
@@ -155,15 +155,14 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
input_params.v_ = v; |
|
|
|
input_params.beta1_ = beta1; |
|
|
|
input_params.beta2_ = beta2; |
|
|
|
const size_t kThreadNum = 16; |
|
|
|
MultiThreadCompute(ComputeMomentum, &input_params, kThreadNum, total_dim_size); |
|
|
|
MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size); |
|
|
|
|
|
|
|
input_params.m_t_ = m_t; |
|
|
|
input_params.use_nesterov_ = use_nesterov_; |
|
|
|
input_params.sparse_grad_ = unique_sparse_grad; |
|
|
|
input_params.var_first_dim_size_ = var_first_dim_size_; |
|
|
|
input_params.var_outer_dim_size_ = var_outer_dim_size_; |
|
|
|
MultiThreadCompute(ComputeAdam, &input_params, kThreadNum, unique_sparse_grad.indices_size_); |
|
|
|
MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_); |
|
|
|
|
|
|
|
if (use_nesterov_) { |
|
|
|
input_params.m_ = input_params.m_t_; |
|
|
|
@@ -171,7 +170,7 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
input_params.var_ = var; |
|
|
|
input_params.lr_ = lr; |
|
|
|
input_params.epsilon_ = epsilon; |
|
|
|
MultiThreadCompute(ComputeWeight, &input_params, kThreadNum, total_dim_size); |
|
|
|
MultiThreadCompute(ComputeWeight, &input_params, total_dim_size); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
|