You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

thread_pool.cpp 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. /**
  2. * \file src/core/impl/utils/thread_pool.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/utils/thread_pool.h"
  12. #include <chrono>
  13. using namespace mgb;
  14. #if MGB_HAVE_THREAD
  15. ThreadPool::ThreadPool(size_t threads_num)
  16. : m_nr_threads(threads_num),
  17. m_main_affinity_flag{false},
  18. m_stop{false},
  19. m_active{false} {
  20. if (threads_num < 1) {
  21. m_nr_threads = 1;
  22. }
  23. if (m_nr_threads > 1) {
  24. if (m_nr_threads > static_cast<uint32_t>(sys::get_cpu_count())) {
  25. mgb_log_debug(
  26. "The number of threads is bigger than number of "
  27. "physical cpu cores, got: %zu core_number: %zu",
  28. static_cast<size_t>(sys::get_cpu_count()), nr_threads());
  29. }
  30. for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
  31. m_workers.push_back(new Worker([this, i]() {
  32. while (!m_stop) {
  33. while (m_active) {
  34. if (m_workers[i]->affinity_flag &&
  35. m_core_binding_function != nullptr) {
  36. m_core_binding_function(i);
  37. m_workers[i]->affinity_flag = false;
  38. }
  39. //! if the thread should work
  40. if (m_workers[i]->work_flag.load(
  41. std::memory_order_acquire)) {
  42. int index = -1;
  43. //! Get one task and execute
  44. while ((index = m_task_iter.fetch_sub(
  45. 1, std::memory_order_acq_rel)) &&
  46. index > 0) {
  47. //! index is decrease, use
  48. //! m_all_task_number - index to get the
  49. //! increase id which will pass to task
  50. m_task(static_cast<size_t>(m_nr_parallelism -
  51. index),
  52. i);
  53. }
  54. //! Flag worker is finished
  55. m_workers[i]->work_flag.store(
  56. false, std::memory_order_release);
  57. }
  58. //! Wait next task coming
  59. std::this_thread::yield();
  60. }
  61. {
  62. std::unique_lock<std::mutex> lock(m_mutex);
  63. if (!m_stop && !m_active) {
  64. m_cv.wait(lock,
  65. [this] { return m_stop || m_active; });
  66. }
  67. }
  68. }
  69. }));
  70. }
  71. }
  72. }
  73. void ThreadPool::add_task(const TaskElem& task_elem) {
  74. //! Make sure the main thread have bind
  75. if (m_main_affinity_flag &&
  76. m_core_binding_function != nullptr) {
  77. std::lock_guard<std::mutex> lock(m_mutex_task);
  78. m_core_binding_function(m_nr_threads - 1);
  79. m_main_affinity_flag = false;
  80. }
  81. size_t parallelism = task_elem.nr_parallelism;
  82. //! If only one thread or one task, execute directly
  83. if (task_elem.nr_parallelism == 1 || m_nr_threads == 1) {
  84. for (size_t i = 0; i < parallelism; i++) {
  85. task_elem.task(i, 0);
  86. }
  87. return;
  88. } else {
  89. std::lock_guard<std::mutex> lock(m_mutex_task);
  90. mgb_assert(m_task_iter.load(std::memory_order_acquire) <= 0,
  91. "The init value of m_all_sub_task is not zero.");
  92. active();
  93. //! Set the task number, task iter and task
  94. m_nr_parallelism = parallelism;
  95. m_task_iter.exchange(parallelism, std::memory_order_relaxed);
  96. m_task = [&task_elem](size_t index, size_t thread_id) {
  97. task_elem.task(index, thread_id);
  98. };
  99. //! Set flag to start thread working
  100. for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
  101. m_workers[i]->work_flag = true;
  102. }
  103. //! Main thread working
  104. int index = -1;
  105. while ((index = m_task_iter.fetch_sub(1, std::memory_order_acq_rel)) &&
  106. (index > 0)) {
  107. m_task(static_cast<size_t>(m_nr_parallelism - index),
  108. m_nr_threads - 1);
  109. }
  110. //! make sure all threads done
  111. sync();
  112. }
  113. }
  114. void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
  115. mgb_assert(affinity_cb, "The affinity callback must not be nullptr");
  116. std::lock_guard<std::mutex> lock(m_mutex_task);
  117. m_core_binding_function = affinity_cb;
  118. for (size_t i = 0; i < m_nr_threads - 1; i++) {
  119. m_workers[i]->affinity_flag = true;
  120. }
  121. m_main_affinity_flag = true;
  122. }
  123. size_t ThreadPool::nr_threads() const {
  124. return m_nr_threads;
  125. }
  126. void ThreadPool::sync() {
  127. bool no_finished = false;
  128. do {
  129. no_finished = false;
  130. for (uint32_t i = 0; i < m_nr_threads - 1; ++i) {
  131. if (m_workers[i]->work_flag) {
  132. no_finished = true;
  133. break;
  134. }
  135. }
  136. if (no_finished) {
  137. std::this_thread::yield();
  138. }
  139. } while (no_finished);
  140. }
  141. void ThreadPool::active() {
  142. if (!m_active) {
  143. std::unique_lock<std::mutex> lock(m_mutex);
  144. m_active = true;
  145. m_cv.notify_all();
  146. }
  147. }
  148. void ThreadPool::deactive() {
  149. std::lock_guard<std::mutex> lock_task(m_mutex_task);
  150. std::unique_lock<std::mutex> lock(m_mutex);
  151. m_active = false;
  152. }
  153. ThreadPool::~ThreadPool() {
  154. std::lock_guard<std::mutex> lock_task(m_mutex_task);
  155. {
  156. std::unique_lock<std::mutex> lock(m_mutex);
  157. m_stop = true;
  158. m_active = false;
  159. m_cv.notify_all();
  160. }
  161. for (auto& worker : m_workers) {
  162. delete worker;
  163. }
  164. }
  165. #else
  166. void ThreadPool::add_task(const TaskElem& task_elem) {
  167. for (size_t i = 0; i < task_elem.nr_parallelism; i++) {
  168. task_elem.task(i, 0);
  169. }
  170. }
  171. void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
  172. mgb_assert(affinity_cb != nullptr, "The affinity callback is nullptr");
  173. affinity_cb(0);
  174. }
  175. #endif
  176. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台