You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parallel_fusion.cc 32 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "common/graph_kernel/parallel_fusion.h"
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "include/common/utils/context/graph_kernel_flags.h"
#include "kernel/kernel.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "kernel/common_utils.h"
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "common/graph_kernel/core/update_state_formatter.h"
#include "common/graph_kernel/core/graph_builder.h"
  30. namespace mindspore::graphkernel {
  31. namespace {
// Cuda's parameter table can accept maximum 4KB, so the number of parameters should be less than 512.
// (Presumably 8 bytes per parameter slot: 4096 / 8 = 512 — see ParameterLimit below.)
constexpr size_t CUDA_PARA_LIMIT = 512;
  34. bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
  35. return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
  36. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  37. }
// Remove "pass-through" nodes (those for which pass_fn returns true, e.g. Reshape or
// TupleGetItem) from the relation graph, reconnecting each of their real predecessors
// directly to their successors so the analysis sees through them.
void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn,
                             OrderedMap<AnfNodePtr, NodeRelation> *const node_rels) {
  std::set<AnfNodePtr> latter_to_be_erased;
  for (const auto &[node, node_rel] : (*node_rels)) {
    if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
      continue;
    }
    auto nexts = node_rel.nexts;
    std::vector<AnfNodePtr> pre_nodes;
    std::queue<AnfNodePtr> node_que;
    node_que.push(node);
    // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
    while (!node_que.empty()) {
      auto cur_node = node_que.front();
      node_que.pop();
      if (!pass_fn(cur_node)) {
        // A real (non pass-through) predecessor: keep it as a reconnection endpoint.
        pre_nodes.push_back(cur_node);
        continue;
      }
      latter_to_be_erased.insert(cur_node);
      // Copy on purpose: the stored sets are erased from inside the loop below.
      auto predecessors = (*node_rels)[cur_node].pres;
      if (predecessors.empty()) {
        continue;
      }
      for (const auto &pre_node : predecessors) {
        (*node_rels)[cur_node].pres.erase(pre_node);
        (*node_rels)[pre_node].nexts.erase(cur_node);
        node_que.push(pre_node);
      }
    }
    // Modify the relation: delete node <-> next_node, add pre node <-> next_node.
    for (const auto &next_node : nexts) {
      (*node_rels)[next_node].pres.erase(node);
      for (const auto &cur_node : pre_nodes) {
        (*node_rels)[next_node].pres.insert(cur_node);
        (*node_rels)[cur_node].nexts.insert(next_node);
      }
    }
  }
  // Finally drop every pass-through node from the relation map.
  for (const auto &node : latter_to_be_erased) {
    node_rels->erase(node);
  }
}
// Erase "tail" MakeTuple nodes — those whose outputs lead nowhere — together with any
// dangling TupleGetItem consumers, so they do not disturb the parallel analysis.
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *const node_rels) {
  AnfNodePtrList latter_to_be_erased;
  for (auto &[node, node_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      continue;
    }
    // Collect the MakeTuple plus all its successors; bail out if any successor
    // is not a TupleGetItem.
    AnfNodePtrList check_next_list;
    check_next_list.push_back(node);
    bool disinterested = false;
    for (auto &successor : node_rel.nexts) {
      if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
        disinterested = true;
        break;
      }
      check_next_list.push_back(successor);
    }
    if (disinterested) {
      continue;
    }
    // Erasable only when the MakeTuple itself and all collected getitems have no
    // consumers. NOTE(review): since `node` is in the list, a MakeTuple with any
    // successors fails this check, which seems to make the getitem-erase loop
    // below unreachable — verify the original intent.
    if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
                     [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
      continue;
    }
    latter_to_be_erased.push_back(node);
  }
  // Delete Tail MakeTuple(including its getitem nodes).
  for (const auto &node : latter_to_be_erased) {
    for (auto &pre : (*node_rels)[node].pres) {
      (*node_rels)[pre].nexts.erase(node);
    }
    // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
    for (auto &getitem : (*node_rels)[node].nexts) {
      node_rels->erase(getitem);
    }
    node_rels->erase(node);
  }
}
  118. bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  119. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
  120. return true;
  121. }
  122. return false;
  123. }
  124. bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  125. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
  126. return true;
  127. }
  128. return false;
  129. }
  130. bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  131. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
  132. return true;
  133. }
  134. return false;
  135. }
  136. bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  137. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
  138. return true;
  139. }
  140. return false;
  141. }
  142. bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  143. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
  144. return true;
  145. }
  146. return false;
  147. }
  148. bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  149. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
  150. return true;
  151. }
  152. return false;
  153. }
// Break "local" relations hanging off no-input nodes so that serial chains behind them
// can still be paralleled. Each deleted relation's tail is recorded in
// `virtual_noout_nodes` and its head in `ignore_noin_nodes`, so the later interest-node
// classification keeps treating both endpoints specially.
void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels,
                           std::set<AnfNodePtr> *const virtual_noout_nodes, std::set<AnfNodePtr> *ignore_noin_nodes) {
  // 1. Local relation
  // Graph as following left part, relation D->B and D->E(D is a no input node)
  // will make B and E to be multiply inputs node.
  // But for parallel, this local relation can ignore for B and E, which make
  // them be able to be paralleled.
  //
  // ************************************
  // *                                  *
  // *  |      |          |     |       *
  // *  A      D          A     D       *
  // *  |     /|          |    / \      *
  // *  | C  | |          | C    F      *
  // *  |/  /  |          | |    |      *
  // *  B  F      ====>   B x    x      *
  // *  | /               |             *
  // *  |/                |             *
  // *  E                 E             *
  // *  |                 |             *
  // *                                  *
  // ************************************
  AnfNodePtrList no_input_nodes;
  for (const auto &node_rel : *node_rels) {
    auto &node = node_rel.first;
    if (IsNoInputsNode(*node_rels, node)) {
      no_input_nodes.push_back(node);
    }
  }
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
  for (const auto &ninode : no_input_nodes) {
    AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
    for (const auto &n : cnexts) {
      // Follow the single-in/single-out chain from the no-input node; `serial_tail`
      // ends on the last chain node, `cur_node` on the first branching node.
      AnfNodePtr serial_tail = ninode;
      AnfNodePtr cur_node = n;
      while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
        serial_tail = cur_node;
        cur_node = *((*node_rels)[cur_node].nexts.begin());
      }
      latter_delete.emplace_back(serial_tail, cur_node);
    }
  }
  // Delete relation.
  for (const auto &[serial_tail, cur_node] : latter_delete) {
    virtual_noout_nodes->insert(serial_tail);
    ignore_noin_nodes->insert(cur_node);
    (*node_rels)[serial_tail].nexts.erase(cur_node);
    (*node_rels)[cur_node].pres.erase(serial_tail);
    MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
                 << cur_node->fullname_with_scope();
  }
}
// Classify every node of the relation graph into four (possibly overlapping) interest
// lists: multi-input nodes, multi-output nodes, no-input nodes and no-output nodes.
// Nodes recorded by ProcessLocalStructure (`ignore_noin_nodes` / `virtual_noout_nodes`)
// are excluded from the no-input / no-output lists respectively.
std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
  const std::set<AnfNodePtr> &ignore_noin_nodes) {
  AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
  // One classifier per result list; every node is run through all of them below.
  std::list<std::function<void(const AnfNodePtr &)>> func_list = {
    [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
      if (IsMultiInputsNode(node_rels, node)) {
        multi_inputs_nodes.push_back(node);
      }
    },
    [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
      if (IsMultiOutputsNode(node_rels, node)) {
        multi_outputs_nodes.push_back(node);
      }
    },
    [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
      if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
        no_input_nodes.push_back(node);
      }
    },
    [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
      if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
        no_output_nodes.push_back(node);
      }
    }};
  for (const auto &node_rel : node_rels) {
    for (const auto &func : func_list) {
      func(node_rel.first);
    }
  }
  return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
}
  238. bool WhiteOpsFilter(const AnfNodePtr &node) {
  239. std::vector<PrimitivePtr> whiteable_ops = {}; // Not special for now.
  240. return common::AnfAlgo::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
  241. }
  242. bool Unfavorable(const AnfNodePtr &node) {
  243. // Parallel cannot work with stitching for now.
  244. auto cnode = node->cast<CNodePtr>();
  245. MS_EXCEPTION_IF_NULL(cnode);
  246. auto input = cnode->input(kAnfPrimitiveIndex);
  247. if (!IsValueNode<FuncGraph>(input)) {
  248. return common::AnfAlgo::HasNodeAttr(kAttrStitch, cnode);
  249. }
  250. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  251. MS_EXCEPTION_IF_NULL(func_graph);
  252. AnfNodePtrList sub_nodes;
  253. kernel::GetValidKernelNodes(func_graph, &sub_nodes);
  254. for (auto sub_node : sub_nodes) {
  255. auto sub_cnode = sub_node->cast<CNodePtr>();
  256. MS_EXCEPTION_IF_NULL(sub_cnode);
  257. if (common::AnfAlgo::HasNodeAttr(kAttrStitch, sub_cnode)) {
  258. return true;
  259. }
  260. }
  261. return false;
  262. }
  263. bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !Unfavorable(node); }
// Walk a single-relation chain from each of `nodes`, collecting maximal "streams" of
// nodes accepted by `filter_func`. A walk stops on an already-seen node, an unknown
// node, or a node with more than one relation in the excluded direction. Returns the
// streams as one group; a group with only one stream is useless for parallelization,
// so it is dropped and its nodes are released from `seen`.
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
                                            const std::function<bool(const AnfNodePtr &)> &filter_func,
                                            const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                            std::set<AnfNodePtr> *const seen) {
  // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
  // For backward search, the other multi-inputs node can be contained in.
  // For forward search, the other multi-outputs node can be contained in.
  // NOTE: the selectors return the relation set by value, i.e. a copy per call.
  auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
                                          : [](const NodeRelation &info) { return info.nexts; };
  auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
                                          : [](const NodeRelation &info) { return info.pres; };
  std::vector<AnfNodePtrList> group;
  for (const auto &node : nodes) {
    AnfNodePtrList stream;
    AnfNodePtr n = node;
    // Follow the chain while the node is new, known, and unique in the excluded direction.
    for (auto iter = node_rels.find(n);
         seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
         iter = node_rels.find(n)) {
      if (filter_func(n)) {
        stream.push_back(n);
        seen->insert(n);
      }
      if (get_contain_node_set(iter->second).size() != 1) {
        break;
      }
      n = *(get_contain_node_set(iter->second).begin());
    }
    if (stream.size() > 0) {
      group.push_back(stream);
    }
  }
  // A single stream cannot be paralleled with anything: undo and return empty.
  if (group.size() == 1) {
    for (const auto &drop : group[0]) {
      seen->erase(drop);
    }
    group.clear();
  }
  return group;
}
  303. void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
  304. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
  305. std::vector<std::vector<AnfNodePtrList>> *groups,
  306. std::set<AnfNodePtr> *const seen) {
  307. auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
  308. : [](const NodeRelation &info) { return info.nexts; };
  309. for (const auto &node : multi_nodes) {
  310. if (auto iter = node_rels.find(node); iter != node_rels.end()) {
  311. const auto &pre_nodes = get_related_nodes(iter->second);
  312. AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
  313. groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
  314. }
  315. }
  316. // Erase empty groups.
  317. for (auto iter = groups->begin(); iter != groups->end();) {
  318. if (iter->size() == 0) {
  319. iter = groups->erase(iter);
  320. } else {
  321. ++iter;
  322. }
  323. }
  324. }
  325. void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
  326. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
  327. std::vector<std::vector<AnfNodePtrList>> *groups,
  328. std::set<AnfNodePtr> *const seen) {
  329. groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
  330. // Erase empty groups.
  331. for (auto iter = groups->begin(); iter != groups->end();) {
  332. if (iter->size() == 0) {
  333. iter = groups->erase(iter);
  334. } else {
  335. ++iter;
  336. }
  337. }
  338. }
  339. std::string DumpNode(const AnfNodePtr &node) {
  340. auto cnode = node->cast<CNodePtr>();
  341. MS_EXCEPTION_IF_NULL(cnode);
  342. std::stringstream buf;
  343. buf << (common::AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
  344. << cnode->ToString();
  345. return buf.str();
  346. }
  347. void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
  348. MS_LOG(INFO) << "[" << title << "]"
  349. << "There are " << groups.size() << " parallel groups, their detail is: ";
  350. int i = 0;
  351. for (const auto group : groups) {
  352. std::stringstream buf;
  353. buf << "[" << i << " group] " << group.size() << ":\n";
  354. for (const auto nodes : group) {
  355. buf << " " << nodes.size() << ": [<";
  356. for (const auto node : nodes) {
  357. buf << "(" << DumpNode(node) << ") -> ";
  358. }
  359. buf << ">]\n";
  360. }
  361. i++;
  362. MS_LOG(INFO) << buf.str();
  363. }
  364. }
  365. void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
  366. std::stringstream buf;
  367. buf << "Parallel fusion detail: ";
  368. for (const auto &node : source) {
  369. buf << "(" << DumpNode(node) << ") + ";
  370. }
  371. buf << "==>"
  372. << "(" << DumpNode(target) << ")";
  373. MS_LOG(INFO) << buf.str();
  374. }
  375. inline bool ParameterLimit(const AnfNodePtrList &nodes) {
  376. if (nodes.empty()) {
  377. MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
  378. }
  379. bool res = true;
  380. switch (AnfAlgo::GetProcessor(nodes[0])) {
  381. case kernel::Processor::CUDA: {
  382. // The number of inputs and outputs for a valid kernel should be less than cuda's limit.
  383. size_t para_count = 0;
  384. for (const auto &node : nodes) {
  385. para_count += common::AnfAlgo::GetInputTensorNum(node);
  386. para_count += common::AnfAlgo::GetOutputTensorNum(node);
  387. }
  388. res = para_count <= CUDA_PARA_LIMIT;
  389. } break;
  390. default:
  391. break;
  392. }
  393. return res;
  394. }
  395. bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); }
  396. } // namespace
  397. OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
  398. // Based on anf node input information, build a simple graph for latter analyzation.
  399. OrderedMap<AnfNodePtr, NodeRelation> node_rels;
  400. auto get_info = [&node_rels](const AnfNodePtr &node) {
  401. if (node_rels.count(node) == 0) {
  402. (void)node_rels.emplace(node, NodeRelation());
  403. }
  404. return &(node_rels[node]);
  405. };
  406. for (const auto &node : nodes) {
  407. if (!node->isa<CNode>()) {
  408. continue;
  409. }
  410. auto prior_node = get_info(node);
  411. for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
  412. if (!input->isa<CNode>()) {
  413. continue;
  414. }
  415. auto behind_node = get_info(input);
  416. prior_node->pres.insert(input);
  417. behind_node->nexts.insert(node);
  418. }
  419. }
  420. ProcessThroughPassCNode(
  421. [](const AnfNodePtr &node) {
  422. return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
  423. },
  424. &node_rels);
  425. ProcessTailMakeTupleCNode(&node_rels);
  426. ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
  427. return node_rels;
  428. }
  429. std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
  430. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
  431. // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
  432. auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
  433. GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
  434. // Get streams and group them
  435. std::set<AnfNodePtr> seen;
  436. std::vector<std::vector<AnfNodePtrList>> groups;
  437. SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
  438. SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
  439. SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
  440. SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
  441. DumpParallelGroups(groups, "Dependency Analyze");
  442. return groups;
  443. }
  444. std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
  445. int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
  446. const std::set<int> &excludes) {
  447. // Get unused nodes by offset index, the result will contain the node with start index.
  448. int node_limit = static_cast<int>(nodes.size());
  449. if (start >= node_limit) {
  450. MS_LOG(EXCEPTION) << "Index offset should be less than the limit of given nodes " << node_limit << ", but got "
  451. << start;
  452. }
  453. AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
  454. std::vector<int> valid_indices;
  455. std::vector<size_t> unused;
  456. for (size_t i = IntToSize(start); i < used.size(); ++i) {
  457. if (!used[i] && excludes.count(i) == 0) {
  458. unused.push_back(i);
  459. }
  460. }
  461. size_t limit = unused.size();
  462. for (auto offset : offsets) {
  463. if (offset >= limit) {
  464. MS_LOG(EXCEPTION) << "Index offset should be less than the limit of unused nodes " << limit << ", but got "
  465. << offset;
  466. }
  467. if (SizeToInt(unused[offset]) >= node_limit) {
  468. MS_LOG(EXCEPTION) << "Index offset should be less than the limit of nodes " << node_limit << ", but got "
  469. << unused[offset];
  470. }
  471. valid_indices.push_back(unused[offset]);
  472. target_nodes.push_back(nodes[unused[offset]]);
  473. }
  474. return std::make_tuple(target_nodes, valid_indices);
  475. }
// Greedy fusion search over candidates sorted by calculation amount: for each unused
// candidate, binary-search the largest number of following unused candidates that can
// be fused with it profitably (cost-model benefit > 0 under the extra conditions) and
// record the best plan. Returns the usage mask indexed in the ORIGINAL candidate
// order, plus the found parallel plans.
std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
  size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
  std::map<AnfNodePtr, int> *sorted_indices) {
  // Fetch the recorded index of a node, throwing when the node was never indexed.
  auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
    MS_EXCEPTION_IF_NULL(node);
    if (indices->find(node) == indices->end()) {
      MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
    }
    return (*indices)[node];
  };
  std::vector<ParallelInfo> parallel_infos;
  std::vector<bool> origin_candidates_used(origin_size, false);
  std::vector<bool> sorted_candidates_used(candidates.size(), false);
  for (size_t i = 0; i < candidates.size(); ++i) {
    if (sorted_candidates_used[i]) {
      continue;
    }
    int max_benefit = 0;
    ParallelInfo best_parallel_info;
    size_t unused_num = 0;
    // Count how many partners remain behind position i.
    for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
      unused_num += sorted_candidates_used[j] ? 0 : 1;
    }
    if (unused_num < 1) {
      break;
    }
    unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
    // Binary search the largest partner count `mid` such that fusing candidate i
    // with the next `mid` unused candidates still has positive benefit.
    size_t begin = 1, end = unused_num;
    while (begin <= end) {
      size_t mid = (begin + end) / 2;
      std::vector<size_t> tc(mid);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      if (ExtraFusionCondition(other_candidates)) {
        int benefit;
        std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
        if (benefit > 0) {
          begin = mid + 1;
          continue;
        }
      }
      end = mid - 1;
    }
    // begin - 1 is now the best partner count; rebuild the plan for it.
    if (begin > 1) {
      std::vector<size_t> tc(begin - 1);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit <= 0) {
        MS_LOG(EXCEPTION) << "Internal error in candidate search! benefit should be greater than 0, but got "
                          << benefit;
      }
      max_benefit = benefit;
      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
      i += begin - 1;
    }
    if (max_benefit > 0) {
      parallel_infos.push_back(best_parallel_info);
      for (const auto &node : best_parallel_info.nodes()) {
        sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
        origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
      }
    }
  }
  // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
  // (parallel_infos.size() is 0 here, so candidates[0] is the node being popped.)
  if (parallel_infos.size() == 0) {
    origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
  }
  return std::make_tuple(origin_candidates_used, parallel_infos);
}
  550. std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
  551. const AnfNodePtrList &cs) {
  552. std::map<AnfNodePtr, int> origin_indices;
  553. std::vector<size_t> indices;
  554. for (size_t i = 0; i < cs.size(); ++i) {
  555. if (cs[i]) {
  556. (void)origin_indices.emplace(cs[i], i);
  557. indices.push_back(i);
  558. }
  559. }
  560. // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
  561. std::map<size_t, int> cal_amounts;
  562. for (auto id : indices) {
  563. cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
  564. }
  565. std::sort(indices.begin(), indices.end(),
  566. [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
  567. AnfNodePtrList candidates;
  568. for (size_t i = 0; i < indices.size(); ++i) {
  569. candidates.push_back(cs[indices[i]]);
  570. }
  571. std::map<AnfNodePtr, int> sorted_indices;
  572. for (size_t i = 0; i < candidates.size(); ++i) {
  573. (void)sorted_indices.emplace(candidates[i], i);
  574. }
  575. return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
  576. }
  577. void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
  578. std::vector<ParallelInfo> *parallel_infos) {
  579. std::vector<AnfNodePtrList::const_iterator> tails;
  580. std::vector<AnfNodePtrList::const_iterator> ended;
  581. for (const auto &node_list : group) {
  582. tails.push_back(node_list.begin());
  583. ended.push_back(node_list.end());
  584. }
  585. auto get_candidates = [&tails, &ended]() {
  586. AnfNodePtrList candidates;
  587. for (size_t id = 0; id < tails.size(); ++id) {
  588. candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
  589. }
  590. return candidates;
  591. };
  592. auto update_tails = [&tails](const std::vector<bool> &used) {
  593. if (used.size() != tails.size()) {
  594. MS_LOG(EXCEPTION) << "Judged nodes size is different from left ones size: " << used.size() << " vs "
  595. << tails.size();
  596. }
  597. for (size_t id = 0; id < used.size(); ++id) {
  598. if (used[id]) {
  599. ++tails[id];
  600. }
  601. }
  602. };
  603. auto valid_candidate_num = [](const AnfNodePtrList &cs) {
  604. return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
  605. };
  606. auto candidates = get_candidates();
  607. while (valid_candidate_num(candidates) > 1) {
  608. auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
  609. std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
  610. [](const ParallelInfo &pi) { return pi; });
  611. update_tails(used);
  612. candidates = get_candidates();
  613. }
  614. }
  615. std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
  616. const std::vector<std::vector<AnfNodePtrList>> &groups) {
  617. // Find core-fusable groups with cost model.
  618. std::vector<ParallelInfo> parallel_infos;
  619. for (const auto &group : groups) {
  620. SearchFuseNodesInParallelGroup(group, &parallel_infos);
  621. }
  622. return parallel_infos;
  623. }
// Attach the parallel dim info (segment index + dim value) to the output node of every
// fused segment, then attach the fusion info to one representative node.
void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
  AnfNodePtr attach_node;
  // Dim info should be attach to each segment's output.
  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
    const auto &fuse_nodes = parallel_info.nodes();
    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
    if (!common::AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
      // Plain node: it carries the attr itself.
      attach_node = fuse_nodes[i];
      SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
    } else {
      // Graph kernel: attach to the sub graph's output node, or to every real
      // output behind a MakeTuple.
      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
      auto out_node = node_g->output();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        auto inputs = out_node->cast<CNodePtr>()->inputs();
        for (size_t j = 1; j < inputs.size(); ++j) {
          SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
        }
        attach_node = inputs[1];
      } else {
        attach_node = out_node;
        SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
      }
    }
  }
  // Fusion info is ok to attach to one of the segments.
  // NOTE(review): attach_node stays null when parallel_info is empty; callers appear
  // to guarantee at least one segment — verify before relying on it.
  SetFusionInfoAttrToNode(attach_node, parallel_info);
}
  651. void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
  652. auto fusion_type = parallel_info.fusion_info()->FusionType();
  653. common::AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
  654. if (parallel_info.fusion_info()->ExistTypeInfo()) {
  655. if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
  656. common::AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
  657. MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
  658. }
  659. }
  660. }
  661. bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
  662. const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  663. bool changed = false;
  664. for (size_t i = 0; i < parallel_infos.size(); ++i) {
  665. const auto &fuse_nodes = parallel_infos[i].nodes();
  666. if (fuse_nodes.size() <= 1) {
  667. continue;
  668. }
  669. changed = true;
  670. SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
  671. auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
  672. common::AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
  673. DumpParallelFusionDetail(fuse_nodes, sg_node);
  674. }
  675. return changed;
  676. }
  677. std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
  678. std::set<AnfNodePtr> captured;
  679. (void)std::for_each(infos.cbegin(), infos.cend(), [&captured](const ParallelInfo &info) {
  680. captured.insert(info.nodes().begin(), info.nodes().end());
  681. });
  682. return captured;
  683. }
// Group the remaining nodes by reverse BFS levels: start from nodes with no successors
// and walk towards predecessors; every level whose parallelizable, non-excluded nodes
// are non-empty becomes one group of single-node streams. Also validates that the
// relation graph is acyclic.
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
                                                                const std::set<AnfNodePtr> &exclude) {
  std::vector<std::vector<AnfNodePtrList>> groups;
  // BFS
  std::queue<AnfNodePtr> node_que;
  std::unordered_map<AnfNodePtr, int> outdegrees;
  for (const auto &[node, ref] : node_rels) {
    outdegrees[node] = SizeToInt(ref.nexts.size());
    if (outdegrees[node] == 0) {
      // Sinks form the first BFS level.
      node_que.push(node);
    }
  }
  int total_node_num = SizeToInt(node_rels.size());
  while (!node_que.empty()) {
    std::vector<AnfNodePtrList> group;
    // Consume exactly one BFS level per outer iteration.
    int node_size = SizeToInt(node_que.size());
    while (node_size--) {
      auto node = node_que.front();
      node_que.pop();
      if (exclude.count(node) == 0 && Parallelizable(node)) {
        (void)group.emplace_back(AnfNodePtrList({node}));
      }
      --total_node_num;
      auto iter = node_rels.find(node);
      if (iter == node_rels.end()) {
        MS_LOG(EXCEPTION) << "Internal error in node relationship!";
      }
      // A predecessor becomes ready once all of its successors were visited.
      for (const auto &pre : iter->second.pres) {
        if (--outdegrees[pre] == 0) {
          node_que.push(pre);
        }
      }
    }
    if (!group.empty()) {
      groups.push_back(group);
    }
  }
  // Every node must be visited exactly once; leftovers mean the graph has a cycle.
  if (total_node_num > 0) {
    MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
  }
  DumpParallelGroups(groups, "BFS");
  return groups;
}
// Pass entry: analyze the kernel graph, search fusable parallel node streams and
// replace each fused set with a composite graph-kernel node.
// Returns true when the graph was changed.
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
  // Shrink update-state structures before the analysis; they are spread back at the end.
  (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
  MS_EXCEPTION_IF_NULL(cost_model_ptr_);
  // Analyze nodes in reverse topological order (outputs first).
  auto nodes = TopoSort(kernel_graph->get_return());
  std::reverse(nodes.begin(), nodes.end());
  auto node_rels = GenAnalysisGraph(nodes);
  auto groups = SearchParallelGroups(node_rels);
  auto parallel_infos = SearchFusableParallelCNodes(groups);
  // Search in BFS for left nodes.
  if (parallel_level_ > 0) {
    auto exclued_nodes = CollectCapturedNodes(parallel_infos);
    auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
    auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
    (void)parallel_infos.insert(parallel_infos.end(), bfs_parallel_infos.begin(), bfs_parallel_infos.end());
  }
  // Create core-fuse subgraph and change origin graph.
  bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
  (void)std::make_shared<SpreadUpdateState>()->Run(graph);
  return changed;
}
  752. } // namespace mindspore::graphkernel