You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

thread_pool.cc 6.0 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/thread_pool.h"
  17. #include <algorithm>
  18. #include "utils/log_adapter.h"
  19. #include "utils/convert_utils_base.h"
  20. namespace mindspore {
  21. #ifdef ENABLE_D
  22. const int kDeviceNum = 8;
  23. #endif
// Producer side of a single-slot SPSC ring: publish one task pointer.
// Returns false when the slot is occupied (caller is expected to retry).
// Ordering: the release store on tail_ publishes buffer_[tail_index] to the
// consumer's acquire load in Dequeue.
bool Queue::Enqueue(Task *task) {
  const int tail_index = tail_.load(std::memory_order_relaxed);
  // Ring has 2 index positions, so at most one task can be in flight.
  auto next = (tail_index + 1) % 2;
  if (next == head_.load(std::memory_order_acquire)) {
    // queue full
    return false;
  }
  buffer_[tail_index] = task;
  tail_.store(next, std::memory_order_release);
  // task_size_ counts outstanding tasks; it is decremented by the worker
  // AFTER the task has run (see AddNewThread), not by Dequeue.
  ++task_size_;
  return true;
}
// Consumer side of the single-slot SPSC ring: take the pending task pointer.
// Returns false when no task is queued. Note task_size_ stays unchanged here;
// the worker decrements it only after executing the task, so waiters polling
// task_size_ (LaunchMultipleTask) observe completion, not just dequeue.
bool Queue::Dequeue(Task **out) {
  // Fast path: nothing outstanding.
  if (task_size_ == 0) {
    return false;
  }
  // queue empty
  const int head_index = head_.load(std::memory_order_relaxed);
  if (head_index == tail_.load(std::memory_order_acquire)) {
    return false;
  }
  *out = buffer_[head_index];
  head_.store((head_index + 1) % 2, std::memory_order_release);
  return true;
}
  49. ThreadPool::ThreadPool() {
  50. #ifdef ENABLE_D
  51. auto cpu_core_num = std::thread::hardware_concurrency();
  52. max_thread_num_ = cpu_core_num / kDeviceNum;
  53. #endif
  54. SetThreadPool(core_thread_num_);
  55. }
  56. bool ThreadPool::SetThreadPool(int config_thread_num) {
  57. std::lock_guard<std::mutex> Lock(pool_mtx_);
  58. if (config_thread_num > max_thread_num_) {
  59. MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
  60. << config_thread_num << ", allowed max thread num=" << max_thread_num_;
  61. }
  62. if (config_thread_num > cur_thread_nums_) {
  63. AddNewThread(config_thread_num - cur_thread_nums_);
  64. }
  65. MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  66. return true;
  67. }
  68. void ThreadPool::AddNewThread(int add_num) {
  69. for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) {
  70. auto active = new std::atomic_bool{true};
  71. auto queue = std::make_shared<Queue>();
  72. std::thread thread([this, i, active, queue]() {
  73. Task *task = nullptr;
  74. while (!exit_run_) {
  75. while (*active) {
  76. if (queue->Dequeue(&task)) {
  77. auto ret = (*task)();
  78. if (ret != SUCCESS) {
  79. error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret)));
  80. }
  81. queue->task_size_--;
  82. }
  83. std::this_thread::yield();
  84. }
  85. std::unique_lock<std::mutex> queue_lock(thread_mtx_);
  86. queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; });
  87. }
  88. });
  89. thread_list_.emplace_back(std::move(thread));
  90. activate_list_.emplace_back(active);
  91. queue_list_.emplace_back(queue);
  92. }
  93. cur_thread_nums_ += add_num;
  94. cur_thread_run_nums_ += add_num;
  95. MS_LOG(INFO) << "add " << add_num << " thread";
  96. }
  97. void ThreadPool::AddRunThread(int num) {
  98. MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  99. int active_nums = num - cur_thread_run_nums_;
  100. if (active_nums <= 0 || static_cast<int>(activate_list_.size()) < active_nums) {
  101. return;
  102. }
  103. for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) {
  104. *activate_list_[i] = true;
  105. }
  106. std::lock_guard<std::mutex> queueLock(thread_mtx_);
  107. queue_ready_.notify_all();
  108. cur_thread_run_nums_ = num;
  109. }
  110. void ThreadPool::SubRunThread(int num) {
  111. MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
  112. int deactive_nums = cur_thread_run_nums_ - num;
  113. if (deactive_nums <= 0) {
  114. return;
  115. }
  116. for (int i = num, j = 0; j < deactive_nums; ++i, ++j) {
  117. *activate_list_[i] = false;
  118. }
  119. cur_thread_run_nums_ = num;
  120. }
  121. bool ThreadPool::LaunchMultipleTask(const std::vector<Task> &tasks) {
  122. int thread_num = tasks.size();
  123. if (thread_num > max_thread_num_) {
  124. thread_num = max_thread_num_;
  125. }
  126. if (!SetThreadPool(thread_num)) {
  127. return false;
  128. }
  129. error_info_.clear();
  130. bool succ_flag;
  131. for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) {
  132. do {
  133. succ_flag = true;
  134. if (!queue_list_[queue_index]->Enqueue(const_cast<Task *>(&tasks[task_id]))) {
  135. std::this_thread::yield();
  136. succ_flag = false;
  137. }
  138. } while (!succ_flag);
  139. queue_index++;
  140. if (queue_index >= cur_thread_run_nums_) {
  141. queue_index = queue_index - cur_thread_run_nums_;
  142. }
  143. }
  144. succ_flag = false;
  145. while (!succ_flag) {
  146. std::this_thread::yield();
  147. succ_flag = true;
  148. for (int i = 0; i < cur_thread_run_nums_; ++i) {
  149. if (queue_list_[i]->task_size_ != 0) {
  150. succ_flag = false;
  151. break;
  152. }
  153. }
  154. }
  155. MS_LOG(INFO) << "Finish " << tasks.size() << " task successful";
  156. return CheckResult();
  157. }
  158. bool ThreadPool::CheckResult() {
  159. bool succ_flag = true;
  160. for (auto result : error_info_) {
  161. if (result.second.first) {
  162. MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second;
  163. succ_flag = false;
  164. }
  165. }
  166. return succ_flag;
  167. }
  168. ThreadPool *ThreadPool::GetInstance() {
  169. static ThreadPool instance;
  170. return &instance;
  171. }
// Shut the pool down: flip every activation flag so spinning workers drop out
// of their inner loop, wake parked workers so they observe exit_run_, join
// all threads, then release the heap-allocated flags.
ThreadPool::~ThreadPool() {
  // Pretend every created thread is "running" so SubRunThread(0) clears the
  // activation flag of ALL threads, not just the currently running subset.
  cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
  exit_run_ = true;
  SubRunThread(0);
  // Parked workers re-check their wait predicate, which passes via exit_run_.
  queue_ready_.notify_all();
  for (auto &it : thread_list_) {
    if (it.joinable()) {
      it.join();
    }
  }
  // Activation flags were allocated with `new` in AddNewThread.
  for (const auto &it : activate_list_) {
    delete it;
  }
}
  186. } // namespace mindspore