You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

recompute.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/recompute.h"
  17. #include <memory>
  18. #include <queue>
  19. #include <utility>
  20. #include <list>
  21. #include <vector>
  22. #include <unordered_map>
  23. #include <unordered_set>
  24. #include <algorithm>
  25. #include "ir/func_graph.h"
  26. #include "mindspore/core/base/core_ops.h"
  27. #include "utils/utils.h"
  28. namespace mindspore {
  29. namespace opt {
  30. namespace {
  31. constexpr auto kGradientsFlag = "Gradients";
  32. bool IsBpropNode(const AnfNodePtr &node) {
  33. MS_EXCEPTION_IF_NULL(node);
  34. if (!node->isa<CNode>()) {
  35. return false;
  36. }
  37. return node->fullname_with_scope().find(kGradientsFlag) == 0;
  38. }
  39. bool WithRecomputedScope(const AnfNodePtr &node) {
  40. MS_EXCEPTION_IF_NULL(node);
  41. if (!node->isa<CNode>()) {
  42. return false;
  43. }
  44. auto full_name_with_scope = node->fullname_with_scope();
  45. return full_name_with_scope.find(kAttrRecompute) == 0;
  46. }
  47. bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
  48. auto cnode = node->cast<CNodePtr>();
  49. if (cnode == nullptr) {
  50. return false;
  51. }
  52. auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
  53. return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
  54. }
  55. bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); }
  56. std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
  57. const std::vector<CNodePtr> &cnodes) {
  58. MS_EXCEPTION_IF_NULL(mng);
  59. std::vector<CNodePtr> candidate_recomputed_nodes;
  60. for (const auto &cnode : cnodes) {
  61. MS_EXCEPTION_IF_NULL(cnode);
  62. if (!IsCandidateRecomputedNode(cnode)) {
  63. continue;
  64. }
  65. // Check outputs.
  66. const auto &node_users = mng->node_users();
  67. auto output_set_iter = node_users.find(cnode);
  68. if (output_set_iter == node_users.end()) {
  69. continue;
  70. }
  71. const auto &node_index_set = output_set_iter->second;
  72. if (!std::any_of(node_index_set.begin(), node_index_set.end(),
  73. [](const std::pair<AnfNodePtr, int> &node_index) { return IsBpropNode(node_index.first); })) {
  74. continue;
  75. }
  76. // Check inputs.
  77. const auto &inputs = cnode->inputs();
  78. if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) {
  79. continue;
  80. }
  81. candidate_recomputed_nodes.emplace_back(cnode);
  82. }
  83. return candidate_recomputed_nodes;
  84. }
  85. void GetMaxSubGraph(const FuncGraphManagerPtr &mng, std::unordered_set<CNodePtr> *recomputed_nodes, bool get_inputs,
  86. bool get_outputs) {
  87. MS_EXCEPTION_IF_NULL(mng);
  88. MS_EXCEPTION_IF_NULL(recomputed_nodes);
  89. std::queue<CNodePtr> nodes_to_visit;
  90. for (const auto &node : *recomputed_nodes) {
  91. nodes_to_visit.push(node);
  92. }
  93. recomputed_nodes->clear();
  94. while (!nodes_to_visit.empty()) {
  95. auto current_node = nodes_to_visit.front();
  96. nodes_to_visit.pop();
  97. recomputed_nodes->insert(current_node);
  98. if (get_inputs) {
  99. for (const auto &input : current_node->inputs()) {
  100. MS_EXCEPTION_IF_NULL(input);
  101. if (input->isa<CNode>()) {
  102. auto input_cnode = input->cast<CNodePtr>();
  103. if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() &&
  104. IsCandidateRecomputedNode(input_cnode)) {
  105. nodes_to_visit.push(input_cnode);
  106. }
  107. }
  108. }
  109. }
  110. if (get_outputs) {
  111. const auto &node_users = mng->node_users();
  112. auto output_set_iter = node_users.find(current_node);
  113. if (output_set_iter == node_users.end()) {
  114. continue;
  115. }
  116. for (const auto &node_index_set : output_set_iter->second) {
  117. auto output_node = node_index_set.first;
  118. MS_EXCEPTION_IF_NULL(output_node);
  119. if (output_node->isa<CNode>()) {
  120. auto output_cnode = output_node->cast<CNodePtr>();
  121. if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() &&
  122. IsCandidateRecomputedNode(output_cnode)) {
  123. nodes_to_visit.push(output_cnode);
  124. }
  125. }
  126. }
  127. }
  128. }
  129. }
  130. void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
  131. const std::unordered_set<CNodePtr> &max_recomputed_sub_graph,
  132. std::unordered_set<CNodePtr> *recompute_nodes,
  133. std::unordered_set<CNodePtr> *target_nodes) {
  134. MS_EXCEPTION_IF_NULL(mng);
  135. MS_EXCEPTION_IF_NULL(recompute_nodes);
  136. MS_EXCEPTION_IF_NULL(target_nodes);
  137. const auto &node_users = mng->node_users();
  138. for (const auto &node : max_recomputed_sub_graph) {
  139. bool inserted = false;
  140. auto output_set_iter = node_users.find(node);
  141. if (output_set_iter == node_users.end()) {
  142. continue;
  143. }
  144. for (const auto &node_index_set : output_set_iter->second) {
  145. auto output_node = node_index_set.first;
  146. MS_EXCEPTION_IF_NULL(output_node);
  147. if (!IsBpropNode(output_node)) {
  148. continue;
  149. }
  150. target_nodes->insert(output_node->cast<CNodePtr>());
  151. if (!inserted) {
  152. recompute_nodes->insert(node);
  153. inserted = true;
  154. }
  155. }
  156. }
  157. }
  158. std::vector<AnfNodePtr> GetFirstTargetInputs(const std::vector<CNodePtr> &origin_nodes_topological,
  159. const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
  160. const std::unordered_set<CNodePtr> &target_nodes) {
  161. std::vector<AnfNodePtr> first_target_inputs;
  162. for (const auto &node : origin_nodes_topological) {
  163. MS_EXCEPTION_IF_NULL(node);
  164. if (target_nodes.find(node) != target_nodes.end()) {
  165. for (size_t i = 1; i < node->size(); ++i) {
  166. auto input = node->input(i);
  167. if (!input->isa<CNode>()) {
  168. continue;
  169. }
  170. MS_EXCEPTION_IF_NULL(input);
  171. if (recomputed_origin_nodes.find(input->cast<CNodePtr>()) != recomputed_origin_nodes.end()) {
  172. continue;
  173. }
  174. first_target_inputs.emplace_back(input);
  175. }
  176. break;
  177. }
  178. }
  179. return first_target_inputs;
  180. }
  181. bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> *has_grad_inputs_map) {
  182. MS_EXCEPTION_IF_NULL(node);
  183. MS_EXCEPTION_IF_NULL(has_grad_inputs_map);
  184. if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
  185. return has_grad_inputs_map->find(node)->second;
  186. }
  187. auto cnode = node->cast<CNodePtr>();
  188. if (cnode == nullptr) {
  189. has_grad_inputs_map->insert(std::make_pair(node, false));
  190. return false;
  191. }
  192. const auto &inputs = cnode->inputs();
  193. if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
  194. return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
  195. })) {
  196. has_grad_inputs_map->insert(std::make_pair(node, true));
  197. return true;
  198. }
  199. has_grad_inputs_map->insert(std::make_pair(node, false));
  200. return false;
  201. }
  202. bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
  203. MS_EXCEPTION_IF_NULL(mng);
  204. const auto &node_users = mng->node_users();
  205. auto output_set_iter = node_users.find(node);
  206. if (output_set_iter == node_users.end()) {
  207. return false;
  208. }
  209. for (const auto &node_index_set : output_set_iter->second) {
  210. if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
  211. return true;
  212. }
  213. }
  214. return false;
  215. }
  216. void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
  217. std::vector<AnfNodePtr> *tuple_getitem_output_nodes) {
  218. MS_EXCEPTION_IF_NULL(mng);
  219. MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes);
  220. const auto &node_users = mng->node_users();
  221. auto output_set_iter = node_users.find(node);
  222. if (output_set_iter == node_users.end()) {
  223. return;
  224. }
  225. for (const auto &node_index_set : output_set_iter->second) {
  226. if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) {
  227. tuple_getitem_output_nodes->emplace_back(node_index_set.first);
  228. }
  229. }
  230. }
  231. // Set 'recompute' cnode attr for the nodes according to its scope.
  232. // A node set 'recompute' cnode attr can become the candidate recomputed node.
  233. void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
  234. MS_EXCEPTION_IF_NULL(graph);
  235. auto mng = graph->manager();
  236. MS_EXCEPTION_IF_NULL(mng);
  237. std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
  238. for (const auto &node : origin_nodes_topological) {
  239. MS_EXCEPTION_IF_NULL(node);
  240. if (IsBpropNode(node)) {
  241. continue;
  242. }
  243. // Do not recompute the communicate op.
  244. if (IsPrimitiveCNode(node, prim::kPrimAllGather)) {
  245. continue;
  246. }
  247. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  248. continue;
  249. }
  250. if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
  251. continue;
  252. }
  253. auto cnode = node->cast<CNodePtr>();
  254. MS_EXCEPTION_IF_NULL(cnode);
  255. auto prim = GetCNodePrimitive(cnode);
  256. if (prim == nullptr) {
  257. continue;
  258. }
  259. auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
  260. int prim_recompute_val = -1;
  261. if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
  262. prim_recompute_val = GetValue<bool>(prim_recompute_attr);
  263. }
  264. if ((WithRecomputedScope(node) && prim_recompute_val != 0) || prim_recompute_val == 1) {
  265. cnode->AddAttr(kAttrRecompute, MakeValue(true));
  266. }
  267. if (!HasRecomputeCNodeAttr(node)) {
  268. continue;
  269. }
  270. // Set attr for the tuple_getitem outputs.
  271. std::vector<AnfNodePtr> tuple_getitem_output_nodes;
  272. GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
  273. for (const auto &output_node : tuple_getitem_output_nodes) {
  274. auto output_cnode = output_node->cast<CNodePtr>();
  275. MS_EXCEPTION_IF_NULL(output_cnode);
  276. output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
  277. }
  278. }
  279. }
  280. CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
  281. const std::vector<AnfNodePtr> &first_target_inputs,
  282. const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
  283. std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  284. MS_EXCEPTION_IF_NULL(graph);
  285. MS_EXCEPTION_IF_NULL(origin_node);
  286. MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes);
  287. auto iter = origin_to_recomputed_nodes->find(origin_node);
  288. if (iter != origin_to_recomputed_nodes->end()) {
  289. return iter->second;
  290. }
  291. MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
  292. std::vector<AnfNodePtr> new_inputs;
  293. bool has_recomputed_inputs = false;
  294. for (size_t i = 0; i < origin_node->size(); ++i) {
  295. auto input = origin_node->input(i);
  296. MS_EXCEPTION_IF_NULL(input);
  297. if (!input->isa<CNode>()) {
  298. new_inputs.emplace_back(input);
  299. continue;
  300. }
  301. auto input_cnode = input->cast<CNodePtr>();
  302. if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
  303. new_inputs.emplace_back(input);
  304. } else {
  305. has_recomputed_inputs = true;
  306. new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,
  307. origin_to_recomputed_nodes));
  308. }
  309. }
  310. // Add the execution dependency.
  311. if (!has_recomputed_inputs && new_inputs.size() > 1) {
  312. std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  313. std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs));
  314. auto first_input = new_inputs[1];
  315. MS_EXCEPTION_IF_NULL(first_input);
  316. std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_input,
  317. graph->NewCNode(make_tuple_inputs)};
  318. auto depend_node = graph->NewCNode(depend_inputs);
  319. MS_EXCEPTION_IF_NULL(depend_node);
  320. depend_node->set_abstract(first_input->abstract());
  321. new_inputs[1] = depend_node;
  322. }
  323. auto recomputed_node = graph->NewCNode(new_inputs);
  324. MS_EXCEPTION_IF_NULL(recomputed_node);
  325. recomputed_node->AddAttr("duplicated", MakeValue(true));
  326. recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
  327. recomputed_node->set_abstract(origin_node->abstract());
  328. recomputed_node->set_scope(origin_node->scope());
  329. origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
  330. return recomputed_node;
  331. }
  332. void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_set<CNodePtr> &target_nodes,
  333. const std::unordered_set<CNodePtr> &origin_recomputed_nodes,
  334. const std::vector<AnfNodePtr> &first_target_inputs,
  335. std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  336. MS_EXCEPTION_IF_NULL(graph);
  337. auto mng = graph->manager();
  338. MS_EXCEPTION_IF_NULL(mng);
  339. for (const auto &target_node : target_nodes) {
  340. MS_EXCEPTION_IF_NULL(target_node);
  341. MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
  342. auto target_cnode = target_node->cast<CNodePtr>();
  343. MS_EXCEPTION_IF_NULL(target_cnode);
  344. std::vector<AnfNodePtr> new_target_inputs;
  345. for (const auto &input : target_cnode->inputs()) {
  346. MS_EXCEPTION_IF_NULL(input);
  347. if (!input->isa<CNode>()) {
  348. new_target_inputs.emplace_back(input);
  349. } else {
  350. auto input_cnode = input->cast<CNodePtr>();
  351. if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) {
  352. new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs,
  353. origin_recomputed_nodes, origin_to_recomputed_nodes));
  354. } else {
  355. new_target_inputs.emplace_back(input_cnode);
  356. }
  357. }
  358. }
  359. auto new_target_node = graph->NewCNode(new_target_inputs);
  360. new_target_node->AddAttr("target_grad", MakeValue(true));
  361. new_target_node->set_abstract(target_node->abstract());
  362. new_target_node->set_scope(target_node->scope());
  363. mng->Replace(target_node, new_target_node);
  364. }
  365. }
  366. } // namespace
  367. void InsertRecomputedNodes(const FuncGraphPtr &graph) {
  368. MS_EXCEPTION_IF_NULL(graph);
  369. auto mng = graph->manager();
  370. MS_EXCEPTION_IF_NULL(mng);
  371. std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  372. std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
  373. SetRecomputedAttr(graph, origin_nodes_topological);
  374. // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
  375. std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
  376. std::unordered_set<CNodePtr> visited_nodes;
  377. for (const auto &candidate_recomputed_node : candidate_recomputed_nodes) {
  378. if (visited_nodes.find(candidate_recomputed_node) != visited_nodes.end()) {
  379. continue;
  380. }
  381. std::unordered_set<CNodePtr> max_recomputed_sub_graph = {candidate_recomputed_node};
  382. // Get max continuous recomputed sub-graph.
  383. GetMaxSubGraph(mng, &max_recomputed_sub_graph, true, true);
  384. visited_nodes.insert(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end());
  385. // Get the origin recomputed nodes which directly output to the grad nodes.
  386. std::unordered_set<CNodePtr> origin_recomputed_nodes;
  387. std::unordered_set<CNodePtr> target_nodes;
  388. GetOriginRecomputeAndTargetNodes(mng, max_recomputed_sub_graph, &origin_recomputed_nodes, &target_nodes);
  389. // Also get the inputs of origin recomputed nodes which eventually output to the grad nodes.
  390. GetMaxSubGraph(mng, &origin_recomputed_nodes, true, false);
  391. // Get the inputs of the first target node in the topological sequence. The duplicated recomputed nodes should
  392. // not be executed until these inputs are ready.
  393. std::vector<AnfNodePtr> first_target_inputs =
  394. GetFirstTargetInputs(origin_nodes_topological, origin_recomputed_nodes, target_nodes);
  395. std::unordered_map<CNodePtr, CNodePtr> origin_to_recomputed_nodes;
  396. // Begin duplicate origin recomputed nodes with each target node.
  397. DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
  398. &origin_to_recomputed_nodes);
  399. }
  400. // Set need cse attr for doing cse after recompute.
  401. for (const auto &node : orders) {
  402. if (WithRecomputedScope(node)) {
  403. node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
  404. }
  405. }
  406. }
  407. } // namespace opt
  408. } // namespace mindspore