|
|
|
@@ -86,11 +86,15 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1; |
|
|
|
auto max_thread_num = std::thread::hardware_concurrency(); |
|
|
|
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; |
|
|
|
MS_LOG(INFO) << "lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; |
|
|
|
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; |
|
|
|
std::vector<std::thread> threads; |
|
|
|
threads.reserve(thread_num); |
|
|
|
size_t start = 0; |
|
|
|
size_t once_compute_size = (lens + thread_num - 1) / thread_num; |
|
|
|
if (thread_num < 1 || once_compute_size < 1) { |
|
|
|
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; |
|
|
|
return false; |
|
|
|
} |
|
|
|
while (start < lens) { |
|
|
|
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); |
|
|
|
threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon, |
|
|
|
|