From: @zhao_ting_v Reviewed-by: @kisnwang,@liangchenghui Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -86,11 +86,15 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1; | size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1; | ||||
| auto max_thread_num = std::thread::hardware_concurrency(); | auto max_thread_num = std::thread::hardware_concurrency(); | ||||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | ||||
| MS_LOG(INFO) << "lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||||
| std::vector<std::thread> threads; | std::vector<std::thread> threads; | ||||
| threads.reserve(thread_num); | threads.reserve(thread_num); | ||||
| size_t start = 0; | size_t start = 0; | ||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | size_t once_compute_size = (lens + thread_num - 1) / thread_num; | ||||
| if (thread_num < 1 || once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; | |||||
| return false; | |||||
| } | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon, | threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon, | ||||
| @@ -217,6 +217,10 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co | |||||
| threads.reserve(thread_num); | threads.reserve(thread_num); | ||||
| size_t start = 0; | size_t start = 0; | ||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | size_t once_compute_size = (lens + thread_num - 1) / thread_num; | ||||
| if (thread_num < 1 || once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; | |||||
| return; | |||||
| } | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| if (operate_type_ == ADD) { | if (operate_type_ == ADD) { | ||||
| @@ -75,6 +75,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||||
| threads.reserve(thread_num); | threads.reserve(thread_num); | ||||
| size_t start = 0; | size_t start = 0; | ||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | size_t once_compute_size = (lens + thread_num - 1) / thread_num; | ||||
| if (thread_num < 1 || once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; | |||||
| return; | |||||
| } | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| if (operate_type_ == SQUARE) { | if (operate_type_ == SQUARE) { | ||||
| @@ -43,6 +43,10 @@ void LaunchCast(const std::vector<kernel::AddressPtr> &inputs, const std::vector | |||||
| threads.reserve(thread_num); | threads.reserve(thread_num); | ||||
| size_t start = 0; | size_t start = 0; | ||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | size_t once_compute_size = (lens + thread_num - 1) / thread_num; | ||||
| if (thread_num < 1 || once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; | |||||
| return; | |||||
| } | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| threads.emplace_back(std::thread(Cast<S, T>, input, output, start, end)); | threads.emplace_back(std::thread(Cast<S, T>, input, output, start, end)); | ||||
| @@ -149,6 +149,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c | |||||
| threads.reserve(thread_num); | threads.reserve(thread_num); | ||||
| size_t start = 0; | size_t start = 0; | ||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | size_t once_compute_size = (lens + thread_num - 1) / thread_num; | ||||
| if (thread_num < 1 || once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size; | |||||
| return; | |||||
| } | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| if (operate_type_ == RELUGRAD) { | if (operate_type_ == RELUGRAD) { | ||||