You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

collective_ops_impl.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/server/collective_ops_impl.h"
  17. namespace mindspore {
  18. namespace ps {
  19. namespace server {
  20. void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
  21. MS_EXCEPTION_IF_NULL(server_node);
  22. server_node_ = server_node;
  23. local_rank_ = server_node_->rank_id();
  24. server_num_ = PSContext::instance()->initial_server_num();
  25. return;
  26. }
  27. template <typename T>
  28. bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
  29. int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
  30. if (ret != 0) {
  31. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  32. return false;
  33. }
  34. uint32_t rank_size = server_num_;
  35. uint32_t local_rank_ = server_node_->rank_id();
  36. size_t chunk_size = count / rank_size;
  37. size_t remainder_size = count % rank_size;
  38. std::vector<size_t> chunk_sizes(rank_size, chunk_size);
  39. // The rest of the data should be assigned to each chunk.
  40. for (size_t i = 0; i < remainder_size; i++) {
  41. chunk_sizes[i]++;
  42. }
  43. // Store offsets to get every data chunk's address.
  44. std::vector<size_t> chunk_offset;
  45. for (size_t i = 0; i < rank_size; i++) {
  46. size_t ofs =
  47. std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
  48. chunk_offset.push_back(ofs);
  49. }
  50. T *output_buff = reinterpret_cast<T *>(recvbuff);
  51. uint32_t send_to_rank = (local_rank_ + 1) % rank_size;
  52. uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size;
  53. MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_
  54. << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
  55. << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
  56. << ", recv_from_rank:" << recv_from_rank;
  57. // Ring ReduceScatter.
  58. MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
  59. std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]);
  60. for (size_t i = 0; i < rank_size - 1; i++) {
  61. // Step 1: Async send data to next rank.
  62. size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
  63. T *send_chunk = output_buff + chunk_offset[send_chunk_index];
  64. auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
  65. chunk_sizes[send_chunk_index] * sizeof(T));
  66. // Step 2: Async receive data to next rank and wait until it's done.
  67. size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
  68. T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
  69. MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
  70. << ", send count:" << chunk_sizes[send_chunk_index]
  71. << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
  72. std::shared_ptr<std::vector<unsigned char>> recv_str;
  73. auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
  74. if (!server_node_->CollectiveWait(recv_req_id, 1)) {
  75. MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
  76. return false;
  77. }
  78. memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
  79. // Step 3: Reduce the data so we can overlap the time cost of send.
  80. for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
  81. recv_chunk[j] += tmp_recv_chunk[j];
  82. }
  83. // Step 4: Wait until send is done.
  84. if (!server_node_->Wait(send_req_id, 1)) {
  85. MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
  86. return false;
  87. }
  88. }
  89. MS_LOG(DEBUG) << "End Ring ReduceScatter.";
  90. // Ring AllGather.
  91. MS_LOG(DEBUG) << "Start Ring AllGather.";
  92. for (size_t i = 0; i < rank_size - 1; i++) {
  93. size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
  94. T *send_chunk = output_buff + chunk_offset[send_chunk_index];
  95. auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
  96. chunk_sizes[send_chunk_index] * sizeof(T));
  97. size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
  98. T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
  99. MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
  100. << ", send count:" << chunk_sizes[send_chunk_index]
  101. << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
  102. std::shared_ptr<std::vector<unsigned char>> recv_str;
  103. auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
  104. if (!server_node_->CollectiveWait(recv_req_id, 1)) {
  105. MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
  106. return false;
  107. }
  108. memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
  109. if (!server_node_->Wait(send_req_id, 1)) {
  110. MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
  111. return false;
  112. }
  113. }
  114. MS_LOG(DEBUG) << "End Ring AllGather.";
  115. return true;
  116. }
  117. template <typename T>
  118. bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
  119. uint32_t rank_size = server_num_;
  120. uint32_t local_rank_ = server_node_->rank_id();
  121. MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
  122. << ", count:" << count;
  123. int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
  124. if (ret != 0) {
  125. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  126. return false;
  127. }
  128. T *output_buff = reinterpret_cast<T *>(recvbuff);
  129. // Reduce data to rank 0 process.
  130. MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
  131. if (local_rank_ == 0) {
  132. std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count);
  133. for (uint32_t i = 1; i < rank_size; i++) {
  134. std::shared_ptr<std::vector<unsigned char>> recv_str;
  135. MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
  136. auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
  137. if (!server_node_->CollectiveWait(recv_req_id, 1)) {
  138. MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
  139. return false;
  140. }
  141. memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
  142. for (size_t j = 0; j < count; j++) {
  143. output_buff[j] += tmp_recv_buff[j];
  144. }
  145. }
  146. } else {
  147. MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
  148. auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
  149. if (!server_node_->Wait(send_req_id, 1)) {
  150. MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
  151. return false;
  152. }
  153. }
  154. MS_LOG(DEBUG) << "End Reduce.";
  155. // Broadcast data to not 0 rank process.
  156. MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
  157. if (local_rank_ == 0) {
  158. for (uint32_t i = 1; i < rank_size; i++) {
  159. MS_LOG(DEBUG) << "Broadcast data to process " << i;
  160. auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
  161. if (!server_node_->Wait(send_req_id, 1)) {
  162. MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
  163. return false;
  164. }
  165. }
  166. } else {
  167. MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
  168. std::shared_ptr<std::vector<unsigned char>> recv_str;
  169. auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
  170. if (!server_node_->CollectiveWait(recv_req_id, 1)) {
  171. MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
  172. return false;
  173. }
  174. memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
  175. }
  176. MS_LOG(DEBUG) << "End broadcast.";
  177. return true;
  178. }
  179. template <typename T>
  180. bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
  181. // The collective communication API does not support calling Send and Recv concurrently with multiple threads;
  182. std::unique_lock<std::mutex> lock(mtx_);
  183. if (sendbuff == nullptr || recvbuff == nullptr) {
  184. MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr.";
  185. return false;
  186. }
  187. uint32_t rank_size = server_num_;
  188. if (count >= rank_size) {
  189. return RingAllReduce<T>(sendbuff, recvbuff, count);
  190. } else {
  191. return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
  192. }
  193. }
// Explicit instantiations: the template definitions live in this .cc file, so
// every element type supported by the collective ops must be instantiated here.
template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
  203. } // namespace server
  204. } // namespace ps
  205. } // namespace mindspore