Browse Source

!11577 Use ThreadPool in ParallelFor

From: @wuxuejian
Reviewed-by: @kisnwang,@guoqi1024,@c_34
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
48f2d82c0e
1 changed files with 9 additions and 7 deletions
  1. +9
    -7
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc

+ 9
- 7
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc View File

@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "common/thread_pool.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@@ -81,21 +82,22 @@ void CPUKernelUtils::GetElementNumEveryDim(const std::vector<size_t> &shape, std
} }
void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) { void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) {
auto max_thread_num = std::thread::hardware_concurrency();
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0; const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
size_t start = 0; size_t start = 0;
size_t once_compute_size = (count + thread_num - 1) / thread_num; size_t once_compute_size = (count + thread_num - 1) / thread_num;
while (start < count) { while (start < count) {
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size); size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
threads.emplace_back(std::thread(task, start, end));
auto block = [&, start, end]() {
task(start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size; start += once_compute_size;
} }
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
} }
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) { std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {


Loading…
Cancel
Save