You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

allreduce_graph.cc 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "parallel/allreduce_fusion/allreduce_graph.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "parallel/allreduce_fusion/allreduce_node.h"
#include "utils/log_adapter.h"
  22. namespace mindspore {
  23. namespace parallel {
  24. Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
  25. AllreduceNodePtr arnode;
  26. auto cnode_emplace_return = cnode_set_.emplace(node);
  27. if (!cnode_emplace_return.second) {
  28. MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
  29. auto cnode_arnode_pair = cnode_arnode_map_.find(node);
  30. if (cnode_arnode_pair == cnode_arnode_map_.end()) {
  31. MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
  32. }
  33. arnode = cnode_arnode_pair->second;
  34. } else {
  35. arnode = std::make_shared<AllreduceNode>(AllreduceNode());
  36. }
  37. if (arnode->Init(node) != SUCCESS) {
  38. MS_LOG(ERROR) << "AllreduceNode Init failed";
  39. return FAILED;
  40. }
  41. if (arnode->AddPara(para) != SUCCESS) {
  42. MS_LOG(ERROR) << "AllreduceNode AddPara failed";
  43. return FAILED;
  44. }
  45. cnode_arnode_map_[node] = arnode;
  46. auto arnode_emplace_return = arnode_set_.insert(arnode);
  47. if (!arnode_emplace_return.second) {
  48. MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
  49. }
  50. cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
  51. if (!cnode_emplace_return.second) {
  52. MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
  53. << "'s cnodeset!";
  54. }
  55. auto para_emplace_return = cnode_paraset_map_[node].emplace(para);
  56. if (!para_emplace_return.second) {
  57. MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString()
  58. << "'s paraset!";
  59. }
  60. return SUCCESS;
  61. }
  62. Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
  63. auto from_arnode_iter = cnode_arnode_map_.find(from);
  64. if (from_arnode_iter == cnode_arnode_map_.end()) {
  65. MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
  66. PrintCNodeSet();
  67. return FAILED;
  68. }
  69. auto to_arnode_iter = cnode_arnode_map_.find(to);
  70. if (to_arnode_iter == cnode_arnode_map_.end()) {
  71. MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added";
  72. PrintCNodeSet();
  73. return FAILED;
  74. }
  75. auto from_arnode = from_arnode_iter->second;
  76. auto to_arnode = to_arnode_iter->second;
  77. if (from_arnode->AddNext(to_arnode) != SUCCESS) {
  78. MS_LOG(ERROR) << "from_arnode AddNext failed";
  79. return FAILED;
  80. }
  81. if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
  82. MS_LOG(ERROR) << "to_arnode AddPrev failed";
  83. return FAILED;
  84. }
  85. max_ = std::max(max_, to_arnode->depend_feat_size());
  86. MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString();
  87. MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size()
  88. << ", to depend_feat_size: " << to_arnode->depend_feat_size();
  89. return SUCCESS;
  90. }
  91. bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
  92. auto cnode_iter = cnode_set_.find(node);
  93. return !(cnode_iter == cnode_set_.end());
  94. }
  95. std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
  96. std::vector<AnfNodePtr> nodes;
  97. for (auto &cnode_arnode : cnode_arnode_map_) {
  98. MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
  99. << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
  100. << " curr_para_size: " << cnode_arnode.second->curr_para_size();
  101. if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) {
  102. (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(),
  103. cnode_paraset_map_[cnode_arnode.first].end());
  104. }
  105. }
  106. return nodes;
  107. }
// Walk arnode_vec_ (sorted in descending depend_feat_size order by
// SortArnode) and collect parameters until at least `para_size` worth of
// parameter size has been gathered, considering only nodes whose
// depend_feat_size is strictly below `to`. A node whose depend_feat_size
// equals max_ is never skipped by the upper bound — presumably the head
// node; verify against set_head_cnode/AddEdge usage.
// Returns the collected parameters together with the depend_feat_size of the
// last consumed node (`from`), which callers can use as the next boundary.
// If para_size <= 0 the loop never stops early and the whole vector is
// consumed.
std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) {
  std::vector<AnfNodePtr> nodes;
  double cur_para_size = 0;
  double from = to;
  for (auto &arnode : arnode_vec_) {
    // Skip nodes at or above the upper boundary, except those at max_.
    if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
      continue;
    }
    // Stop once enough parameter size has been accumulated AND we have
    // crossed a depend_feat_size boundary — nodes sharing the same
    // depend_feat_size as the previous one are still consumed together.
    if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
      return std::make_pair(nodes, from);
    }
    (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end());
    cur_para_size += arnode.curr_para_size();
    from = arnode.depend_feat_size();
  }
  // Reaching here means the loop exhausted arnode_vec_ (head node reached).
  MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size
               << " cur_para_size: " << cur_para_size << " from: " << from;
  return std::make_pair(nodes, from);
}
  127. void AllreduceGraph::PrintCNodeSet() const {
  128. MS_LOG(INFO) << "CNodeSet:";
  129. for (auto &cnode : cnode_set_) {
  130. MS_LOG(INFO) << cnode->DebugString();
  131. }
  132. }
  133. void AllreduceGraph::PrintAllredueGraphInfo() const {
  134. MS_LOG(INFO) << "max: " << max_;
  135. for (auto &cnode_arnode : cnode_arnode_map_) {
  136. MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
  137. MS_LOG(INFO) << "arnode info: ";
  138. cnode_arnode.second->ToString();
  139. }
  140. }
  141. void AllreduceGraph::PrintArnodeVec() const {
  142. MS_LOG(INFO) << "ArnodeVec:";
  143. for (auto &arnode : arnode_vec_) {
  144. arnode.ToString();
  145. }
  146. }
  147. void AllreduceGraph::PrintArnodeSet() const {
  148. MS_LOG(INFO) << "ArnodeSet:";
  149. for (auto &arnode : arnode_set_) {
  150. arnode->ToString();
  151. }
  152. }
  153. void AllreduceGraph::SortArnode() {
  154. arnode_vec_.clear();
  155. for (auto &node : arnode_set_) {
  156. arnode_vec_.emplace_back(*node);
  157. }
  158. std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
  159. }
  160. Status AllreduceGraph::RemoveExtraParas() {
  161. std::unordered_set<AnfNodePtr> para_map;
  162. for (auto &node : arnode_vec_) {
  163. for (auto &para : node.paras()) {
  164. auto emplac_result = para_map.emplace(para);
  165. if (!emplac_result.second) {
  166. MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
  167. if (node.RemovePara(para) != SUCCESS) {
  168. MS_LOG(ERROR) << "remove para failed";
  169. return FAILED;
  170. }
  171. }
  172. }
  173. }
  174. return SUCCESS;
  175. }
  176. Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
  177. auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
  178. if (arnode->Init(node) != SUCCESS) {
  179. MS_LOG(ERROR) << "AllreduceNode Init failed";
  180. }
  181. head_cnode_ = node;
  182. cnode_arnode_map_[node] = arnode;
  183. auto arnode_emplace_return = arnode_set_.insert(arnode);
  184. if (!arnode_emplace_return.second) {
  185. MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!";
  186. }
  187. auto cnode_emplace_return = cnode_set_.emplace(node);
  188. if (!cnode_emplace_return.second) {
  189. MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!";
  190. }
  191. return SUCCESS;
  192. }
  193. } // namespace parallel
  194. } // namespace mindspore