You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

thread_pool.cc 8.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/thread_pool.h"
  17. #include <algorithm>
  18. #include <exception>
  19. #include "utils/log_adapter.h"
  20. #include "utils/convert_utils_base.h"
  21. #include "utils/ms_exception.h"
namespace mindspore {
namespace common {
#ifdef ENABLE_D
// Number of Ascend devices sharing one host; host cores are divided among them.
const int kDeviceNum = 8;
#endif
// Hard upper bound on the number of worker threads the pool may create.
const int kMaxThreadNum = 23;
// Pushes one task pointer into the single-producer ring buffer.
// Returns false when the ring is full; the caller retries (see InnerSyncRun).
// NOTE(review): the "% 2" hard-codes a two-slot ring (one task in flight at a
// time) -- confirm against the buffer_ declaration in thread_pool.h.
bool Queue::Enqueue(Task *task) {
  const int tail_index = tail_.load(std::memory_order_relaxed);
  // queue full
  // The slot the tail would advance to; if it catches up with head, the ring is full.
  auto next = (tail_index + 1) % 2;
  if (next == head_.load(std::memory_order_acquire)) {
    return false;
  }
  buffer_[tail_index] = task;
  // Release publishes the buffer_ write before the consumer observes the new tail.
  tail_.store(next, std::memory_order_release);
  // task_size_ counts enqueued-but-not-yet-finished tasks; the worker thread
  // decrements it after the task has executed (see AddNewThread), not Dequeue.
  ++task_size_;
  return true;
}
// Pops one task pointer from the ring into *out.
// Returns false when no task is pending.
bool Queue::Dequeue(Task **out) {
  // Fast path: nothing outstanding. task_size_ stays nonzero until the worker
  // finishes the task, so it can lag behind head_/tail_; that only causes a
  // harmless extra head/tail comparison below.
  if (task_size_ == 0) {
    return false;
  }
  // queue empty
  const int head_index = head_.load(std::memory_order_relaxed);
  if (head_index == tail_.load(std::memory_order_acquire)) {
    return false;
  }
  *out = buffer_[head_index];
  // Free the slot; "% 2" mirrors the two-slot ring used by Enqueue.
  head_.store((head_index + 1) % 2, std::memory_order_release);
  return true;
}
  53. ThreadPool::ThreadPool() {
  54. int process_core_num = std::thread::hardware_concurrency() - 1;
  55. if (process_core_num < 1) {
  56. process_core_num = 1;
  57. }
  58. #ifdef ENABLE_D
  59. max_thread_num_ = process_core_num / kDeviceNum;
  60. #else
  61. max_thread_num_ = process_core_num;
  62. #endif
  63. if (max_thread_num_ < 1) {
  64. max_thread_num_ = 1;
  65. }
  66. if (max_thread_num_ > kMaxThreadNum) {
  67. max_thread_num_ = kMaxThreadNum;
  68. }
  69. }
  70. bool ThreadPool::SetThreadPool(int config_thread_num) {
  71. if (config_thread_num > max_thread_num_) {
  72. MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
  73. << config_thread_num << ", allowed max thread num=" << max_thread_num_;
  74. }
  75. if (config_thread_num > cur_thread_nums_) {
  76. AddNewThread(config_thread_num - cur_thread_nums_);
  77. }
  78. MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  79. return true;
  80. }
// Spawns `add_num` worker threads, each with its own lock-free queue and an
// activation flag. An active worker spins draining its queue; a deactivated
// one parks on queue_ready_ until re-activated (AddRunThread) or shutdown.
// NOTE(review): `active` is an owning raw pointer freed in ClearThreadPool via
// activate_list_, and error_info_ is appended from multiple workers without a
// lock -- both are pre-existing design choices, left unchanged here.
void ThreadPool::AddNewThread(int add_num) {
  for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) {
    auto active = new std::atomic_bool{true};
    auto queue = std::make_shared<Queue>();
    std::thread thread([this, i, active, queue]() {
      Task *task = nullptr;
      while (!exit_run_) {
        // Busy loop: run tasks from the private queue while this worker is active.
        while (*active) {
          if (queue->Dequeue(&task)) {
            int ret;
            try {
              ret = (*task)();
            } catch (std::exception &e) {
              // Record the exception for re-throw on the caller's thread.
              ret = FAIL;
              MsException::Instance().SetException();
            }
            if (ret != SUCCESS) {
              // (thread index, (flag, error code)) -- consumed by CheckResult.
              error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret)));
            }
            // Decrement only after execution: InnerSyncRun polls task_size_ to
            // know when all dispatched work has truly finished.
            queue->task_size_--;
          }
          std::this_thread::yield();
        }
        // Deactivated: sleep until re-activated or the pool shuts down.
        std::unique_lock<std::mutex> queue_lock(thread_mtx_);
        queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; });
      }
    });
    thread_list_.emplace_back(std::move(thread));
    activate_list_.emplace_back(active);
    queue_list_.emplace_back(queue);
  }
  cur_thread_nums_ += add_num;
  cur_thread_run_nums_ += add_num;
  MS_LOG(INFO) << "add " << add_num << " thread";
}
  116. void ThreadPool::AddRunThread(int num) {
  117. MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  118. int active_nums = num - cur_thread_run_nums_;
  119. if (active_nums <= 0 || static_cast<int>(activate_list_.size()) < active_nums) {
  120. return;
  121. }
  122. for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) {
  123. *activate_list_[i] = true;
  124. }
  125. std::lock_guard<std::mutex> queueLock(thread_mtx_);
  126. queue_ready_.notify_all();
  127. cur_thread_run_nums_ = num;
  128. }
  129. void ThreadPool::SubRunThread(int num) {
  130. MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  131. int deactive_nums = cur_thread_run_nums_ - num;
  132. if (deactive_nums <= 0) {
  133. return;
  134. }
  135. for (int i = num, j = 0; j < deactive_nums; ++i, ++j) {
  136. *activate_list_[i] = false;
  137. }
  138. cur_thread_run_nums_ = num;
  139. }
// Worker loop for the SyncRun path: blocks on task_cond_var_ until a task is
// queued (or shutdown is requested), runs it, then bumps
// task_finished_count_ so SyncRun can detect completion of the whole batch.
void ThreadPool::SyncRunLoop() {
  while (true) {
    Task task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
      if (exit_run_) {
        return;
      }
      task = task_queue_.front();
      task_queue_.pop();
    }
    // Execute outside the lock so other workers can dequeue concurrently.
    try {
      task();
    } catch (std::exception &e) {
      // Record the exception for re-throw on the caller's thread.
      MsException::Instance().SetException();
    }
    {
      // The counter is guarded by task_mutex_, matching the wait in SyncRun.
      std::unique_lock<std::mutex> task_lock(task_mutex_);
      task_finished_count_ = task_finished_count_ + 1;
    }
    finished_cond_var_.notify_one();
  }
}
  164. bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
  165. if (tasks.size() == 1) {
  166. auto ret = tasks[0]();
  167. return ret == SUCCESS;
  168. }
  169. std::unique_lock<std::mutex> lock(pool_mtx_);
  170. exit_run_ = false;
  171. int task_num = tasks.size();
  172. int thread_num = sync_run_threads_.size();
  173. if (thread_num < max_thread_num_ && thread_num < task_num) {
  174. auto new_thread_num = max_thread_num_;
  175. if (task_num < max_thread_num_) {
  176. new_thread_num = task_num;
  177. }
  178. for (int i = thread_num; i < new_thread_num; ++i) {
  179. sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
  180. }
  181. }
  182. for (auto &task : tasks) {
  183. std::lock_guard<std::mutex> task_lock(task_mutex_);
  184. task_queue_.push(task);
  185. task_cond_var_.notify_one();
  186. }
  187. {
  188. std::unique_lock<std::mutex> task_lock(task_mutex_);
  189. finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
  190. task_finished_count_ = 0;
  191. }
  192. return true;
  193. }
// Dispatches `tasks` round-robin onto the per-worker lock-free queues and
// busy-waits until every queue has drained, then reports per-task errors via
// CheckResult. Enqueue only fails transiently (ring momentarily full), hence
// the yield-and-retry loop.
bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  // Cap worker count at the pool maximum; extra tasks share queues round-robin.
  int thread_num = tasks.size();
  if (thread_num > max_thread_num_) {
    thread_num = max_thread_num_;
  }
  if (!SetThreadPool(thread_num)) {
    return false;
  }
  error_info_.clear();
  bool succ_flag;
  for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) {
    // Retry until the worker's two-slot ring has room for this task.
    do {
      succ_flag = true;
      if (!queue_list_[queue_index]->Enqueue(const_cast<Task *>(&tasks[task_id]))) {
        std::this_thread::yield();
        succ_flag = false;
      }
    } while (!succ_flag);
    // Advance round-robin over the currently running workers.
    queue_index++;
    if (queue_index >= cur_thread_run_nums_) {
      queue_index = queue_index - cur_thread_run_nums_;
    }
  }
  // Spin until every queue reports zero outstanding tasks. task_size_ is only
  // decremented after a task has executed, so this is a completion barrier.
  succ_flag = false;
  while (!succ_flag) {
    std::this_thread::yield();
    succ_flag = true;
    for (int i = 0; i < cur_thread_run_nums_; ++i) {
      if (queue_list_[i]->task_size_ != 0) {
        succ_flag = false;
        break;
      }
    }
  }
  MS_LOG(INFO) << "Finish " << tasks.size() << " task successful";
  return CheckResult();
}
  232. bool ThreadPool::CheckResult() {
  233. bool succ_flag = true;
  234. for (auto result : error_info_) {
  235. if (result.second.first) {
  236. MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second;
  237. succ_flag = false;
  238. }
  239. }
  240. return succ_flag;
  241. }
// Returns the process-wide pool. Meyers singleton: the function-local static
// is constructed lazily and thread-safely (guaranteed since C++11).
ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool instance;
  return instance;
}
// Idempotent shutdown: sets exit_run_, wakes every parked worker, joins both
// thread teams (SyncRun workers and queue workers), and frees the
// heap-allocated activation flags. Safe to call more than once.
void ThreadPool::ClearThreadPool() {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  if (exit_run_) {
    return;
  }
  exit_run_ = true;
  // Treat every created thread as "running" so SubRunThread(0) clears all flags.
  cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
  SubRunThread(0);
  // Wake workers parked on either condition variable so they see exit_run_.
  queue_ready_.notify_all();
  task_cond_var_.notify_all();
  for (auto &it : sync_run_threads_) {
    if (it.joinable()) {
      it.join();
    }
  }
  sync_run_threads_.clear();
  for (auto &it : thread_list_) {
    if (it.joinable()) {
      it.join();
    }
  }
  thread_list_.clear();
  // Activation flags were allocated with `new` in AddNewThread.
  for (const auto &it : activate_list_) {
    delete it;
  }
  activate_list_.clear();
}
// Joins all workers and releases activation flags before destruction.
ThreadPool::~ThreadPool() { ClearThreadPool(); }
  274. } // namespace common
  275. } // namespace mindspore