You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parallel_fusion.cc 32 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/parallel_fusion.h"
  17. #include <algorithm>
  18. #include <list>
  19. #include <queue>
  20. #include <unordered_map>
  21. #include <utility>
  22. #include "include/common/utils/context/graph_kernel_flags.h"
  23. #include "kernel/kernel.h"
  24. #include "common/graph_kernel/graph_kernel_helper.h"
  25. #include "kernel/common_utils.h"
  26. #include "frontend/operator/ops.h"
  27. #include "ir/func_graph_cloner.h"
  28. #include "common/graph_kernel/core/update_state_formatter.h"
  29. #include "common/graph_kernel/core/graph_builder.h"
  30. namespace mindspore::graphkernel {
  31. namespace {
  32. // Cuda's parameter table can accept maximum 4KB, so the number of parameters should be less than 512.
  33. constexpr size_t CUDA_PARA_LIMIT = 512;
  34. bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
  35. return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
  36. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  37. }
// Remove "pass-through" nodes (those for which pass_fn returns true) from the
// relation graph, reconnecting each removed node's surviving predecessors
// directly to its successors. Chains of pass-through nodes are collapsed
// transitively via the BFS queue below.
void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn,
                             OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  std::set<AnfNodePtr> latter_to_be_erased;
  for (const auto &[node, node_rel] : (*node_rels)) {
    // Skip non-pass-through nodes and nodes already scheduled for removal.
    if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
      continue;
    }
    auto nexts = node_rel.nexts;
    std::vector<AnfNodePtr> pre_nodes;
    std::queue<AnfNodePtr> node_que;
    node_que.push(node);
    // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
    while (!node_que.empty()) {
      auto cur_node = node_que.front();
      node_que.pop();
      if (!pass_fn(cur_node)) {
        // Reached a "real" node; it becomes a new direct predecessor.
        pre_nodes.push_back(cur_node);
        continue;
      }
      (void)latter_to_be_erased.insert(cur_node);
      // Copy the predecessor set: the relation entries are mutated in the loop below.
      auto predecessors = (*node_rels)[cur_node].pres;
      if (predecessors.empty()) {
        continue;
      }
      for (const auto &pre_node : predecessors) {
        (void)(*node_rels)[cur_node].pres.erase(pre_node);
        (void)(*node_rels)[pre_node].nexts.erase(cur_node);
        node_que.push(pre_node);
      }
    }
    // Modify the relation: delete node <-> next_node, add pre node <-> next_node.
    for (const auto &next_node : nexts) {
      (void)(*node_rels)[next_node].pres.erase(node);
      for (const auto &cur_node : pre_nodes) {
        (void)(*node_rels)[next_node].pres.insert(cur_node);
        (void)(*node_rels)[cur_node].nexts.insert(next_node);
      }
    }
  }
  // Finally drop every collapsed pass-through node from the map.
  for (const auto &node : latter_to_be_erased) {
    (void)node_rels->erase(node);
  }
}
// Remove "tail" MakeTuple nodes (and their TupleGetItem consumers) from the
// relation graph. A MakeTuple qualifies only when all of its successors are
// TupleGetItem nodes and every node in {MakeTuple, getitems} has no successors
// recorded in node_rels — i.e. the tuple only packages dead-end outputs.
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  AnfNodePtrList latter_to_be_erased;
  for (auto &[node, node_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      continue;
    }
    // Collect the MakeTuple itself plus all of its successors; give up if any
    // successor is not a TupleGetItem.
    AnfNodePtrList check_next_list;
    check_next_list.push_back(node);
    bool disinterested = false;
    for (auto &successor : node_rel.nexts) {
      if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
        disinterested = true;
        break;
      }
      check_next_list.push_back(successor);
    }
    if (disinterested) {
      continue;
    }
    // Only schedule for removal when nothing in the collected list is consumed.
    if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
                     [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
      continue;
    }
    latter_to_be_erased.push_back(node);
  }
  // Delete Tail MakeTuple(including its getitem nodes).
  for (const auto &node : latter_to_be_erased) {
    // Unlink the MakeTuple from its producers.
    for (auto &pre : (*node_rels)[node].pres) {
      (void)(*node_rels)[pre].nexts.erase(node);
    }
    // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
    for (auto &getitem : (*node_rels)[node].nexts) {
      (void)node_rels->erase(getitem);
    }
    (void)node_rels->erase(node);
  }
}
  118. bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  119. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
  120. return true;
  121. }
  122. return false;
  123. }
  124. bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  125. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
  126. return true;
  127. }
  128. return false;
  129. }
  130. bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  131. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
  132. return true;
  133. }
  134. return false;
  135. }
  136. bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  137. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
  138. return true;
  139. }
  140. return false;
  141. }
  142. bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  143. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
  144. return true;
  145. }
  146. return false;
  147. }
  148. bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
  149. if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
  150. return true;
  151. }
  152. return false;
  153. }
// Break "local" relations introduced by no-input nodes so the affected nodes
// can still be paralleled. The broken endpoints are recorded in
// `virtual_noout_nodes` / `ignore_noin_nodes` so later interest-node collection
// does not misclassify them.
void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
                           std::set<AnfNodePtr> *ignore_noin_nodes) {
  // 1. Local relation
  // Graph as following left part, relation D->B and D->E(D is a no input node)
  // will make B and E to be multiply inputs node.
  // But for parallel, this local relation can ignore for B and E, which make
  // them be able to be paralleled.
  //
  // ************************************
  // *                                  *
  // *    |       |                     *
  // *    A   D   A   D                 *
  // *    | / |   |  / \                *
  // *    | C |   | C   F               *
  // *    |/ / |  | |   |               *
  // *    B F  ====> B   x    x         *
  // *    | /     |                     *
  // *    |/      |                     *
  // *    E       E                     *
  // *    |       |                     *
  // *                                  *
  // ************************************
  // Collect all nodes without predecessors.
  AnfNodePtrList no_input_nodes;
  for (const auto &node_rel : *node_rels) {
    auto &node = node_rel.first;
    if (IsNoInputsNode(*node_rels, node)) {
      no_input_nodes.push_back(node);
    }
  }
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
  // For each no-input node, follow every successor down its single-in/single-out
  // chain; the edge (chain tail -> first branching node) is the local relation
  // to cut.
  for (const auto &ninode : no_input_nodes) {
    AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
    for (const auto &n : cnexts) {
      AnfNodePtr serial_tail = ninode;
      AnfNodePtr cur_node = n;
      while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
        serial_tail = cur_node;
        cur_node = *((*node_rels)[cur_node].nexts.begin());
      }
      (void)latter_delete.emplace_back(serial_tail, cur_node);
    }
  }
  // Delete relation.
  for (const auto &[serial_tail, cur_node] : latter_delete) {
    (void)virtual_noout_nodes->insert(serial_tail);
    (void)ignore_noin_nodes->insert(cur_node);
    (void)(*node_rels)[serial_tail].nexts.erase(cur_node);
    (void)(*node_rels)[cur_node].pres.erase(serial_tail);
    MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
                 << cur_node->fullname_with_scope();
  }
}
  206. std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
  207. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
  208. const std::set<AnfNodePtr> &ignore_noin_nodes) {
  209. AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
  210. std::list<std::function<void(const AnfNodePtr &)>> func_list = {
  211. [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
  212. if (IsMultiInputsNode(node_rels, node)) {
  213. multi_inputs_nodes.push_back(node);
  214. }
  215. },
  216. [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
  217. if (IsMultiOutputsNode(node_rels, node)) {
  218. multi_outputs_nodes.push_back(node);
  219. }
  220. },
  221. [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
  222. if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
  223. no_input_nodes.push_back(node);
  224. }
  225. },
  226. [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
  227. if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
  228. no_output_nodes.push_back(node);
  229. }
  230. }};
  231. for (const auto &node_rel : node_rels) {
  232. for (const auto &func : func_list) {
  233. func(node_rel.first);
  234. }
  235. }
  236. return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
  237. }
  238. bool WhiteOpsFilter(const AnfNodePtr &node) {
  239. std::vector<PrimitivePtr> whiteable_ops = {}; // Not special for now.
  240. return common::AnfAlgo::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
  241. }
  242. bool Unfavorable(const AnfNodePtr &node) {
  243. // Parallel cannot work with stitching for now.
  244. auto cnode = node->cast<CNodePtr>();
  245. MS_EXCEPTION_IF_NULL(cnode);
  246. auto input = cnode->input(kAnfPrimitiveIndex);
  247. if (!IsValueNode<FuncGraph>(input)) {
  248. return common::AnfAlgo::HasNodeAttr(kAttrStitch, cnode);
  249. }
  250. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  251. MS_EXCEPTION_IF_NULL(func_graph);
  252. AnfNodePtrList sub_nodes;
  253. kernel::GetValidKernelNodes(func_graph, &sub_nodes);
  254. for (auto sub_node : sub_nodes) {
  255. auto sub_cnode = sub_node->cast<CNodePtr>();
  256. MS_EXCEPTION_IF_NULL(sub_cnode);
  257. if (common::AnfAlgo::HasNodeAttr(kAttrStitch, sub_cnode)) {
  258. return true;
  259. }
  260. }
  261. return false;
  262. }
  263. bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !Unfavorable(node); }
// From each start node, walk a serial chain (backward through `pres` or
// forward through `nexts`) and collect the nodes accepted by `filter_func`
// into one stream. Returns the streams found; nodes collected are marked in
// `seen` so other searches will not claim them again.
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
                                            const std::function<bool(const AnfNodePtr &)> &filter_func,
                                            const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                            std::set<AnfNodePtr> *seen) {
  // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
  // For backward search, the other multi-inputs node can be contained in.
  // For forward search, the other multi-outputs node can be contained in.
  auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
                                          : [](const NodeRelation &info) { return info.nexts; };
  auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
                                          : [](const NodeRelation &info) { return info.pres; };
  std::vector<AnfNodePtrList> group;
  for (const auto &node : nodes) {
    AnfNodePtrList stream;
    AnfNodePtr n = node;
    // Follow the chain while the current node is unseen, known, and has at
    // most one relation in the "exclude" direction (i.e. the chain is serial).
    for (auto iter = node_rels.find(n);
         seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
         iter = node_rels.find(n)) {
      if (filter_func(n)) {
        stream.push_back(n);
        (void)seen->insert(n);
      }
      // Stop when the walk direction branches or dead-ends.
      if (get_contain_node_set(iter->second).size() != 1) {
        break;
      }
      n = *(get_contain_node_set(iter->second).cbegin());
    }
    if (stream.size() > 0) {
      group.push_back(stream);
    }
  }
  // A single stream has nothing to run in parallel with; release its nodes so
  // later searches can still use them.
  if (group.size() == 1) {
    for (const auto &drop : group[0]) {
      (void)seen->erase(drop);
    }
    group.clear();
  }
  return group;
}
  303. void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
  304. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
  305. std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
  306. auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
  307. : [](const NodeRelation &info) { return info.nexts; };
  308. for (const auto &node : multi_nodes) {
  309. if (auto iter = node_rels.find(node); iter != node_rels.end()) {
  310. const auto &pre_nodes = get_related_nodes(iter->second);
  311. AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
  312. groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
  313. }
  314. }
  315. // Erase empty groups.
  316. for (auto iter = groups->begin(); iter != groups->end();) {
  317. if (iter->size() == 0) {
  318. iter = groups->erase(iter);
  319. } else {
  320. ++iter;
  321. }
  322. }
  323. }
  324. void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
  325. const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
  326. std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
  327. groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
  328. // Erase empty groups.
  329. for (auto iter = groups->begin(); iter != groups->end();) {
  330. if (iter->size() == 0) {
  331. iter = groups->erase(iter);
  332. } else {
  333. ++iter;
  334. }
  335. }
  336. }
  337. std::string DumpNode(const AnfNodePtr &node) {
  338. auto cnode = node->cast<CNodePtr>();
  339. MS_EXCEPTION_IF_NULL(cnode);
  340. std::stringstream buf;
  341. buf << (common::AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
  342. << cnode->ToString();
  343. return buf.str();
  344. }
  345. void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
  346. MS_LOG(INFO) << "[" << title << "]"
  347. << "There are " << groups.size() << " parallel groups, their detail is: ";
  348. int i = 0;
  349. for (const auto group : groups) {
  350. std::stringstream buf;
  351. buf << "[" << i << " group] " << group.size() << ":\n";
  352. for (const auto nodes : group) {
  353. buf << " " << nodes.size() << ": [<";
  354. for (const auto node : nodes) {
  355. buf << "(" << DumpNode(node) << ") -> ";
  356. }
  357. buf << ">]\n";
  358. }
  359. i++;
  360. MS_LOG(INFO) << buf.str();
  361. }
  362. }
  363. void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
  364. std::stringstream buf;
  365. buf << "Parallel fusion detail: ";
  366. for (const auto &node : source) {
  367. buf << "(" << DumpNode(node) << ") + ";
  368. }
  369. buf << "==>"
  370. << "(" << DumpNode(target) << ")";
  371. MS_LOG(INFO) << buf.str();
  372. }
  373. inline bool ParameterLimit(const AnfNodePtrList &nodes) {
  374. if (nodes.empty()) {
  375. MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
  376. }
  377. bool res = true;
  378. auto processor_type = AnfAlgo::GetProcessor(nodes[0]);
  379. if (processor_type == kernel::Processor::CUDA) {
  380. // The number of inputs and outputs for a valid kernel should be less than cuda's limit.
  381. size_t para_count = 0;
  382. for (const auto &node : nodes) {
  383. para_count += common::AnfAlgo::GetInputTensorNum(node);
  384. para_count += common::AnfAlgo::GetOutputTensorNum(node);
  385. }
  386. res = para_count <= CUDA_PARA_LIMIT;
  387. }
  388. return res;
  389. }
// Extra pre-fusion screening beyond the cost model; currently only the
// parameter-count limit. Kept as a separate hook for future conditions.
bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); }
  391. } // namespace
// Build the node-relation graph used by all later analysis: one NodeRelation
// (pres/nexts sets) per CNode, then simplify it with three passes.
// Side effect: ProcessLocalStructure fills virtual_noout_nodes_ /
// ignore_noin_nodes_ members consumed by SearchParallelGroups.
OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
  // Based on anf node input information, build a simple graph for latter analyzation.
  OrderedMap<AnfNodePtr, NodeRelation> node_rels;
  // Get (and lazily create) the relation record of a node.
  auto get_info = [&node_rels](const AnfNodePtr &node) {
    if (node_rels.count(node) == 0) {
      (void)node_rels.emplace(node, NodeRelation());
    }
    return &(node_rels[node]);
  };
  // Add an edge input -> node for every CNode input of every CNode.
  for (const auto &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto prior_node = get_info(node);
    for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
      if (!input->isa<CNode>()) {
        continue;
      }
      auto behind_node = get_info(input);
      (void)prior_node->pres.insert(input);
      (void)behind_node->nexts.insert(node);
    }
  }
  // Collapse shape-only / pass-through ops so they do not break serial streams.
  ProcessThroughPassCNode(
    [](const AnfNodePtr &node) {
      return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
    },
    &node_rels);
  // Drop output-packaging MakeTuple tails.
  ProcessTailMakeTupleCNode(&node_rels);
  // Cut trivial local relations; remembers cut endpoints in member sets.
  ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
  return node_rels;
}
// Find groups of serial streams that may run in parallel, based on the
// dependency structure of the relation graph.
std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
  // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
  auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
    GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
  // Get streams and group them. The call order matters: each search marks its
  // captured nodes in `seen`, so later searches cannot claim them again.
  std::set<AnfNodePtr> seen;
  std::vector<std::vector<AnfNodePtrList>> groups;
  SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
  DumpParallelGroups(groups, "Dependency Analyze");
  return groups;
}
// Collect nodes[start] plus the still-unused nodes addressed by `offsets`.
// Note: each offset indexes into the pool of unused (and not excluded)
// positions at or after `start`, not into `nodes` directly.
// Returns the chosen nodes and the real indices of the offset-selected ones;
// the start node itself is not reported in the index list.
std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
  int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
  const std::set<int> &excludes) {
  // Get unused nodes by offset index, the result will contain the node with start index.
  int node_limit = static_cast<int>(nodes.size());
  if (start >= node_limit) {
    MS_LOG(EXCEPTION) << "Index offset should be less than the limit of given nodes " << node_limit << ", but got "
                      << start;
  }
  AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
  std::vector<int> valid_indices;
  std::vector<size_t> unused;
  // Build the candidate pool: indices >= start that are neither used nor excluded.
  for (size_t i = IntToSize(start); i < used.size(); ++i) {
    if (!used[i] && excludes.count(i) == 0) {
      unused.push_back(i);
    }
  }
  size_t limit = unused.size();
  for (auto offset : offsets) {
    // Both checks guard against caller errors; violations are internal bugs.
    if (offset >= limit) {
      MS_LOG(EXCEPTION) << "Index offset should be less than the limit of unused nodes " << limit << ", but got "
                        << offset;
    }
    if (SizeToInt(unused[offset]) >= node_limit) {
      MS_LOG(EXCEPTION) << "Index offset should be less than the limit of nodes " << node_limit << ", but got "
                        << unused[offset];
    }
    valid_indices.push_back(unused[offset]);
    target_nodes.push_back(nodes[unused[offset]]);
  }
  return std::make_tuple(target_nodes, valid_indices);
}
// Greedily pick fusion sets from `candidates` (sorted by calculation amount).
// For each unfused candidate i, binary-search the largest count of following
// unused candidates that still gives a positive fusion benefit under the cost
// model, then record that set as one ParallelInfo.
// Returns: (flags over the ORIGINAL candidate order marking fused nodes,
//           the collected fusion plans).
std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
  size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
  std::map<AnfNodePtr, int> *sorted_indices) {
  // Look up a node's position in the given index map; missing entries are internal bugs.
  auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
    MS_EXCEPTION_IF_NULL(node);
    if (indices->find(node) == indices->end()) {
      MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
    }
    return (*indices)[node];
  };
  std::vector<ParallelInfo> parallel_infos;
  std::vector<bool> origin_candidates_used(origin_size, false);
  std::vector<bool> sorted_candidates_used(candidates.size(), false);
  size_t offset;
  // `offset` is how many extra candidates were fused with i, so the next start
  // position skips past the whole fused set.
  for (size_t i = 0; i < candidates.size(); i += offset + 1) {
    offset = 0;
    if (sorted_candidates_used[i]) {
      continue;
    }
    int max_benefit = 0;
    ParallelInfo best_parallel_info;
    // Count how many candidates after i are still available as fusion partners.
    size_t unused_num = 0;
    for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
      unused_num += sorted_candidates_used[j] ? 0 : 1;
    }
    if (unused_num < 1) {
      break;
    }
    unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
    // Binary search the largest partner count `mid` whose fusion (i + mid
    // partners) passes the extra condition and has positive benefit.
    size_t begin = 1, end = unused_num;
    while (begin <= end) {
      size_t mid = (begin + end) / 2;
      // Offsets 1..mid select the next `mid` unused candidates after i.
      std::vector<size_t> tc(mid);
      for (size_t idx = 0; idx < mid; idx++) {
        tc[idx] = idx + 1;
      }
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      if (ExtraFusionCondition(other_candidates)) {
        int benefit;
        std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
        if (benefit > 0) {
          begin = mid + 1;
          continue;
        }
      }
      end = mid - 1;
    }
    // begin > 1 means some size worked; re-evaluate the best size (begin - 1)
    // to obtain its dim infos and fusion info.
    if (begin > 1) {
      std::vector<size_t> tc(begin - 1);
      for (size_t idx = 0; idx < begin - 1; idx++) {
        tc[idx] = idx + 1;
      }
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit <= 0) {
        MS_LOG(EXCEPTION) << "Internal error in candidate search! benefit should be greater than 0, but got "
                          << benefit;
      }
      max_benefit = benefit;
      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
      offset = begin - 1;
    }
    // Commit the plan: mark all fused nodes used in both index spaces.
    if (max_benefit > 0) {
      parallel_infos.push_back(best_parallel_info);
      for (const auto &node : best_parallel_info.nodes()) {
        sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
        origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
      }
    }
  }
  // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
  // (parallel_infos.size() is 0 here, so this marks candidates[0] as used.)
  if (parallel_infos.size() == 0) {
    origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
  }
  return std::make_tuple(origin_candidates_used, parallel_infos);
}
  551. std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
  552. const AnfNodePtrList &cs) {
  553. std::map<AnfNodePtr, int> origin_indices;
  554. std::vector<size_t> indices;
  555. for (size_t i = 0; i < cs.size(); ++i) {
  556. if (cs[i]) {
  557. origin_indices[cs[i]] = SizeToInt(i);
  558. indices.push_back(i);
  559. }
  560. }
  561. // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
  562. std::map<size_t, int> cal_amounts;
  563. for (auto id : indices) {
  564. cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
  565. }
  566. std::sort(indices.begin(), indices.end(),
  567. [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
  568. AnfNodePtrList candidates;
  569. for (size_t i = 0; i < indices.size(); ++i) {
  570. candidates.push_back(cs[indices[i]]);
  571. }
  572. std::map<AnfNodePtr, int> sorted_indices;
  573. for (size_t i = 0; i < candidates.size(); ++i) {
  574. sorted_indices[candidates[i]] = SizeToInt(i);
  575. }
  576. return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
  577. }
// Search fusable sets within one parallel group. Each stream keeps a cursor
// (`tails`); in every round the current front node of each stream forms the
// candidate list, the candidate search runs, and the cursors of all judged
// nodes advance. Stops when at most one stream still has nodes left.
void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                                      std::vector<ParallelInfo> *parallel_infos) {
  std::vector<AnfNodePtrList::const_iterator> tails;
  std::vector<AnfNodePtrList::const_iterator> ended;
  for (const auto &node_list : group) {
    tails.push_back(node_list.begin());
    ended.push_back(node_list.end());
  }
  // Current front node of each stream; exhausted streams contribute a null entry.
  auto get_candidates = [&tails, &ended]() {
    AnfNodePtrList candidates;
    for (size_t id = 0; id < tails.size(); ++id) {
      candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
    }
    return candidates;
  };
  // Advance the cursor of every stream whose front node was judged this round.
  auto update_tails = [&tails](const std::vector<bool> &used) {
    if (used.size() != tails.size()) {
      MS_LOG(EXCEPTION) << "Judged nodes size is different from left ones size: " << used.size() << " vs "
                        << tails.size();
    }
    for (size_t id = 0; id < used.size(); ++id) {
      if (used[id]) {
        ++tails[id];
      }
    }
  };
  // Number of streams that still have a front node.
  auto valid_candidate_num = [](const AnfNodePtrList &cs) {
    return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
  };
  auto candidates = get_candidates();
  while (valid_candidate_num(candidates) > 1) {
    auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
    (void)std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
                         [](const ParallelInfo &pi) { return pi; });
    update_tails(used);
    candidates = get_candidates();
  }
}
  616. std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
  617. const std::vector<std::vector<AnfNodePtrList>> &groups) {
  618. // Find core-fusable groups with cost model.
  619. std::vector<ParallelInfo> parallel_infos;
  620. for (const auto &group : groups) {
  621. SearchFuseNodesInParallelGroup(group, &parallel_infos);
  622. }
  623. return parallel_infos;
  624. }
// Attach the parallel dim info attr to every fused segment's output node(s),
// and the fusion info attr to one representative node.
void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
  AnfNodePtr attach_node;
  // Dim info should be attach to each segment's output.
  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
    const auto &fuse_nodes = parallel_info.nodes();
    // info = {segment index, parallel dim of this segment}.
    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
    if (!common::AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
      // Plain op: attach the attr to the op itself.
      attach_node = fuse_nodes[i];
      SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
    } else {
      // Graph kernel: attach the attr to the real output node(s) inside the subgraph.
      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
      auto out_node = node_g->output();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        // Multi-output subgraph: tag every tuple element.
        auto inputs = out_node->cast<CNodePtr>()->inputs();
        for (size_t j = 1; j < inputs.size(); ++j) {
          SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
        }
        attach_node = inputs[1];
      } else {
        attach_node = out_node;
        SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
      }
    }
  }
  // Fusion info is ok to attach to one of the segments.
  // NOTE(review): assumes GetSize() > 0 so attach_node has been set — confirm with callers.
  SetFusionInfoAttrToNode(attach_node, parallel_info);
}
  652. void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
  653. auto fusion_type = parallel_info.fusion_info()->FusionType();
  654. common::AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
  655. if (parallel_info.fusion_info()->ExistTypeInfo()) {
  656. if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
  657. common::AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
  658. MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
  659. }
  660. }
  661. }
  662. bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
  663. const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  664. bool changed = false;
  665. for (size_t i = 0; i < parallel_infos.size(); ++i) {
  666. const auto &fuse_nodes = parallel_infos[i].nodes();
  667. if (fuse_nodes.size() <= 1) {
  668. continue;
  669. }
  670. changed = true;
  671. SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
  672. auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
  673. common::AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
  674. DumpParallelFusionDetail(fuse_nodes, sg_node);
  675. }
  676. return changed;
  677. }
  678. std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
  679. std::set<AnfNodePtr> captured;
  680. (void)std::for_each(infos.cbegin(), infos.cend(), [&captured](const ParallelInfo &info) {
  681. captured.insert(info.nodes().begin(), info.nodes().end());
  682. });
  683. return captured;
  684. }
// Build parallel groups by reverse BFS levels: start from nodes with no
// consumers (outdegree 0) and walk toward inputs. All parallelizable,
// not-yet-captured nodes discovered in the same level form one group, each as
// a single-node stream. Raises if the relation graph contains a cycle.
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
                                                                const std::set<AnfNodePtr> &exclude) {
  std::vector<std::vector<AnfNodePtrList>> groups;
  // BFS
  std::queue<AnfNodePtr> node_que;
  std::unordered_map<AnfNodePtr, int> outdegrees;
  // Seed the queue with all sink nodes (no successors).
  for (const auto &[node, ref] : node_rels) {
    outdegrees[node] = SizeToInt(ref.nexts.size());
    if (outdegrees[node] == 0) {
      node_que.push(node);
    }
  }
  // Used to detect cycles: every node must be dequeued exactly once.
  int total_node_num = SizeToInt(node_rels.size());
  while (!node_que.empty()) {
    std::vector<AnfNodePtrList> group;
    // Process exactly one BFS level per outer iteration.
    int node_size = SizeToInt(node_que.size());
    while (node_size--) {
      auto node = node_que.front();
      node_que.pop();
      if (exclude.count(node) == 0 && Parallelizable(node)) {
        (void)group.emplace_back(AnfNodePtrList({node}));
      }
      --total_node_num;
      auto iter = node_rels.find(node);
      if (iter == node_rels.end()) {
        MS_LOG(EXCEPTION) << "Internal error in node relationship!";
      }
      // A predecessor becomes ready once all of its consumers were processed.
      for (const auto &pre : iter->second.pres) {
        if (--outdegrees[pre] == 0) {
          node_que.push(pre);
        }
      }
    }
    if (!group.empty()) {
      groups.push_back(group);
    }
  }
  if (total_node_num > 0) {
    MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
  }
  DumpParallelGroups(groups, "BFS");
  return groups;
}
// Pass entry. Pipeline:
//   1. shrink UpdateState nodes so they do not block analysis;
//   2. build the relation graph from a reversed topological order;
//   3. dependency-analysis search for parallel groups and fusion plans;
//   4. when parallel_ops_level > 0, a second BFS search over the nodes not
//      captured yet;
//   5. materialize the plans as graph-kernel subgraphs and spread the
//      UpdateState nodes back.
// Returns true when the graph was changed.
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
  (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
  MS_EXCEPTION_IF_NULL(cost_model_ptr_);
  // Reverse the topological order so analysis starts from the outputs.
  auto nodes = TopoSort(kernel_graph->get_return());
  std::reverse(nodes.begin(), nodes.end());
  auto node_rels = GenAnalysisGraph(nodes);
  auto groups = SearchParallelGroups(node_rels);
  auto parallel_infos = SearchFusableParallelCNodes(groups);
  // Search in BFS for left nodes.
  if (parallel_level_ > 0) {
    auto exclued_nodes = CollectCapturedNodes(parallel_infos);
    auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
    auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
    (void)parallel_infos.insert(parallel_infos.end(), bfs_parallel_infos.begin(), bfs_parallel_infos.end());
  }
  // Create core-fuse subgraph and change origin graph.
  bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
  (void)std::make_shared<SpreadUpdateState>()->Run(graph);
  return changed;
}
  753. } // namespace mindspore::graphkernel