You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

recompute.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/recompute.h"
  17. #include <memory>
  18. #include <queue>
  19. #include <utility>
  20. #include <list>
  21. #include <vector>
  22. #include <algorithm>
  23. #include "utils/hash_map.h"
  24. #include "utils/hash_set.h"
  25. #include "ir/func_graph.h"
  26. #include "mindspore/core/base/core_ops.h"
  27. #include "utils/utils.h"
  28. namespace mindspore {
  29. namespace opt {
  30. namespace {
// Scope-name prefix that marks a node as belonging to the backward (gradient) pass.
constexpr auto kGradientsFlag = "Gradients";
// Offset added to the fusion id of duplicated parallel-optimizer AllGather nodes so the
// recomputed communication ops are fused separately from the original ones.
const int64_t fusion_id_increasement_size = 2000;
  33. bool CanNotRecomputed(const CNodePtr &node) {
  34. static mindspore::HashSet<PrimitivePtr> not_recomputed_op_list{
  35. prim::kPrimDropoutGenMask, prim::kPrimLoad, prim::kPrimTupleGetItem, prim::kPrimSend, prim::kPrimReceive};
  36. return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
  37. [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  38. }
  39. bool IsBpropNode(const AnfNodePtr &node) {
  40. MS_EXCEPTION_IF_NULL(node);
  41. if (!node->isa<CNode>()) {
  42. return false;
  43. }
  44. return node->fullname_with_scope().find(kGradientsFlag) == 0;
  45. }
  46. bool WithRecomputedScope(const AnfNodePtr &node) {
  47. MS_EXCEPTION_IF_NULL(node);
  48. if (!node->isa<CNode>()) {
  49. return false;
  50. }
  51. auto full_name_with_scope = node->fullname_with_scope();
  52. return full_name_with_scope.find(kAttrRecompute) == 0;
  53. }
  54. ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) {
  55. MS_EXCEPTION_IF_NULL(node);
  56. auto cnode = node->cast<CNodePtr>();
  57. if (cnode == nullptr) {
  58. return nullptr;
  59. }
  60. return cnode->GetAttr(kAttrRecompute);
  61. }
  62. bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) {
  63. auto cnode_recompute_val = GetRecomputeCNodeAttr(node);
  64. return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && !GetValue<bool>(cnode_recompute_val);
  65. }
  66. bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) {
  67. auto cnode_recompute_val = GetRecomputeCNodeAttr(node);
  68. return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
  69. }
// A candidate recomputed node is a forward-pass cnode explicitly tagged with recompute=true.
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); }
  71. std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
  72. const std::vector<CNodePtr> &cnodes) {
  73. MS_EXCEPTION_IF_NULL(mng);
  74. std::vector<CNodePtr> candidate_recomputed_nodes;
  75. for (const auto &cnode : cnodes) {
  76. MS_EXCEPTION_IF_NULL(cnode);
  77. if (!IsCandidateRecomputedNode(cnode)) {
  78. continue;
  79. }
  80. // Check outputs.
  81. const auto &node_users = mng->node_users();
  82. auto output_set_iter = node_users.find(cnode);
  83. if (output_set_iter == node_users.end()) {
  84. continue;
  85. }
  86. const auto &node_index_set = output_set_iter->second;
  87. if (!std::any_of(node_index_set.begin(), node_index_set.end(),
  88. [](const auto &node_index) { return IsBpropNode(node_index.first); })) {
  89. continue;
  90. }
  91. // Check inputs.
  92. const auto &inputs = cnode->inputs();
  93. if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) {
  94. continue;
  95. }
  96. (void)candidate_recomputed_nodes.emplace_back(cnode);
  97. }
  98. return candidate_recomputed_nodes;
  99. }
  100. void GetMaxSubGraph(const FuncGraphManagerPtr &mng, mindspore::HashSet<CNodePtr> *recomputed_nodes, bool get_inputs,
  101. bool get_outputs) {
  102. MS_EXCEPTION_IF_NULL(mng);
  103. MS_EXCEPTION_IF_NULL(recomputed_nodes);
  104. std::queue<CNodePtr> nodes_to_visit;
  105. for (const auto &node : *recomputed_nodes) {
  106. nodes_to_visit.push(node);
  107. }
  108. recomputed_nodes->clear();
  109. while (!nodes_to_visit.empty()) {
  110. auto current_node = nodes_to_visit.front();
  111. nodes_to_visit.pop();
  112. recomputed_nodes->insert(current_node);
  113. // No need to find nodes through side-effect dependency.
  114. if (IsPrimitiveCNode(current_node, prim::kPrimUpdateState)) {
  115. continue;
  116. }
  117. if (get_inputs) {
  118. for (const auto &input : current_node->inputs()) {
  119. MS_EXCEPTION_IF_NULL(input);
  120. if (input->isa<CNode>()) {
  121. auto input_cnode = input->cast<CNodePtr>();
  122. if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() &&
  123. IsCandidateRecomputedNode(input_cnode)) {
  124. nodes_to_visit.push(input_cnode);
  125. }
  126. }
  127. }
  128. }
  129. if (get_outputs) {
  130. const auto &node_users = mng->node_users();
  131. auto output_set_iter = node_users.find(current_node);
  132. if (output_set_iter == node_users.end()) {
  133. continue;
  134. }
  135. for (const auto &node_index_set : output_set_iter->second) {
  136. auto output_node = node_index_set.first;
  137. MS_EXCEPTION_IF_NULL(output_node);
  138. if (output_node->isa<CNode>()) {
  139. auto output_cnode = output_node->cast<CNodePtr>();
  140. if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() &&
  141. IsCandidateRecomputedNode(output_cnode)) {
  142. nodes_to_visit.push(output_cnode);
  143. }
  144. }
  145. }
  146. }
  147. }
  148. }
  149. void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
  150. const mindspore::HashSet<CNodePtr> &max_recomputed_sub_graph,
  151. mindspore::HashSet<CNodePtr> *recompute_nodes,
  152. mindspore::HashSet<CNodePtr> *target_nodes) {
  153. MS_EXCEPTION_IF_NULL(mng);
  154. MS_EXCEPTION_IF_NULL(recompute_nodes);
  155. MS_EXCEPTION_IF_NULL(target_nodes);
  156. const auto &node_users = mng->node_users();
  157. for (const auto &node : max_recomputed_sub_graph) {
  158. bool inserted = false;
  159. auto output_set_iter = node_users.find(node);
  160. if (output_set_iter == node_users.end()) {
  161. continue;
  162. }
  163. for (const auto &node_index_set : output_set_iter->second) {
  164. auto output_node = node_index_set.first;
  165. MS_EXCEPTION_IF_NULL(output_node);
  166. if (!IsBpropNode(output_node)) {
  167. continue;
  168. }
  169. target_nodes->insert(output_node->cast<CNodePtr>());
  170. if (!inserted) {
  171. recompute_nodes->insert(node);
  172. inserted = true;
  173. }
  174. }
  175. }
  176. }
  177. std::vector<AnfNodePtr> GetFirstTargetInputs(const std::vector<CNodePtr> &origin_nodes_topological,
  178. const mindspore::HashSet<CNodePtr> &recomputed_origin_nodes,
  179. const mindspore::HashSet<CNodePtr> &target_nodes) {
  180. std::vector<AnfNodePtr> first_target_inputs;
  181. for (const auto &node : origin_nodes_topological) {
  182. MS_EXCEPTION_IF_NULL(node);
  183. if (target_nodes.find(node) != target_nodes.end()) {
  184. for (size_t i = 1; i < node->size(); ++i) {
  185. auto input = node->input(i);
  186. MS_EXCEPTION_IF_NULL(input);
  187. if (!input->isa<CNode>()) {
  188. continue;
  189. }
  190. if (recomputed_origin_nodes.find(input->cast<CNodePtr>()) != recomputed_origin_nodes.end()) {
  191. continue;
  192. }
  193. (void)first_target_inputs.emplace_back(input);
  194. }
  195. break;
  196. }
  197. }
  198. return first_target_inputs;
  199. }
  200. bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, bool> *has_grad_inputs_map) {
  201. MS_EXCEPTION_IF_NULL(node);
  202. MS_EXCEPTION_IF_NULL(has_grad_inputs_map);
  203. if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
  204. return has_grad_inputs_map->find(node)->second;
  205. }
  206. auto cnode = node->cast<CNodePtr>();
  207. if (cnode == nullptr) {
  208. (void)has_grad_inputs_map->emplace(node, false);
  209. return false;
  210. }
  211. const auto &inputs = cnode->inputs();
  212. for (size_t i = 0; i < inputs.size(); ++i) {
  213. // For the pipeline split case, the forward pass may depend on the backward pass.
  214. if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) {
  215. continue;
  216. }
  217. if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) {
  218. (void)has_grad_inputs_map->emplace(node, true);
  219. return true;
  220. }
  221. }
  222. (void)has_grad_inputs_map->emplace(node, false);
  223. return false;
  224. }
  225. bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
  226. MS_EXCEPTION_IF_NULL(mng);
  227. const auto &node_users = mng->node_users();
  228. auto output_set_iter = node_users.find(node);
  229. if (output_set_iter == node_users.end()) {
  230. return false;
  231. }
  232. return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
  233. [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); });
  234. }
  235. void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
  236. std::vector<AnfNodePtr> *tuple_getitem_output_nodes) {
  237. MS_EXCEPTION_IF_NULL(mng);
  238. MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes);
  239. const auto &node_users = mng->node_users();
  240. auto output_set_iter = node_users.find(node);
  241. if (output_set_iter == node_users.end()) {
  242. return;
  243. }
  244. for (const auto &node_index_set : output_set_iter->second) {
  245. if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) {
  246. (void)tuple_getitem_output_nodes->emplace_back(node_index_set.first);
  247. }
  248. }
  249. }
// Predicate (despite the "Set" prefix): true if the node itself is under a recompute
// scope, or it is a Depend whose real input is under a recompute scope.
bool SetRecomputedScope(const CNodePtr &node) {
  return WithRecomputedScope(node) ||
         (IsPrimitiveCNode(node, prim::kPrimDepend) && WithRecomputedScope(node->input(kRealInputIndexInDepend)));
}
// Set 'recompute' cnode attr for the nodes according to its scope.
// A node set 'recompute' cnode attr can become the candidate recomputed node.
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  mindspore::HashMap<AnfNodePtr, bool> has_grad_inputs_map;
  for (const auto &node : origin_nodes_topological) {
    MS_EXCEPTION_IF_NULL(node);
    // The node may be set the non-recomputed before such as the cell outputs.
    if (IsSetNoRecomputeCNodeAttr(node)) {
      continue;
    }
    // Backward-pass nodes are never recomputed.
    if (IsBpropNode(node)) {
      continue;
    }
    // Filter some unrecomputable operators.
    if (CanNotRecomputed(node)) {
      continue;
    }
    // Only forward nodes that feed the forward pass and take no grad inputs qualify.
    if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto prim = GetCNodePrimitive(cnode);
    if (prim == nullptr) {
      continue;
    }
    // Tri-state primitive attr: -1 = unset, 0 = recompute=false, 1 = recompute=true.
    auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
    int prim_recompute_val = -1;
    if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
      prim_recompute_val = static_cast<int>(GetValue<bool>(prim_recompute_attr));
    }
    // A recompute scope enables the attr unless the primitive explicitly disables it;
    // a primitive-level recompute=true enables it regardless of scope.
    if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) {
      cnode->AddAttr(kAttrRecompute, MakeValue(true));
    }
    if (!IsSetRecomputeCNodeAttr(node)) {
      continue;
    }
    // Set attr for the tuple_getitem outputs.
    std::vector<AnfNodePtr> tuple_getitem_output_nodes;
    GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes);
    for (const auto &output_node : tuple_getitem_output_nodes) {
      auto output_cnode = output_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      output_cnode->AddAttr(kAttrRecompute, MakeValue(true));
    }
  }
}
  304. CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
  305. const std::vector<AnfNodePtr> &new_inputs) {
  306. auto recomputed_node = graph->NewCNode(new_inputs);
  307. MS_EXCEPTION_IF_NULL(recomputed_node);
  308. recomputed_node->AddAttr("duplicated", MakeValue(true));
  309. recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
  310. recomputed_node->set_abstract(origin_node->abstract());
  311. recomputed_node->set_scope(origin_node->scope());
  312. return recomputed_node;
  313. }
// Recursively duplicates origin_node (and its inputs that are also recomputed origin
// nodes) to build the recomputed sub-graph. Results are memoized in
// *origin_to_recomputed_nodes so each origin node is duplicated at most once.
// A duplicated node whose inputs are all non-recomputed is additionally made to depend
// on first_target_inputs, so it can not execute before the backward pass is ready.
CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                           const std::vector<AnfNodePtr> &first_target_inputs,
                           const mindspore::HashSet<CNodePtr> &recomputed_origin_nodes,
                           mindspore::HashMap<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes);
  // Reuse an already duplicated node if possible.
  auto iter = origin_to_recomputed_nodes->find(origin_node);
  if (iter != origin_to_recomputed_nodes->end()) {
    return iter->second;
  }
  MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
  std::vector<AnfNodePtr> new_inputs;
  bool has_recomputed_inputs = false;
  for (size_t i = 0; i < origin_node->size(); ++i) {
    auto input = origin_node->input(i);
    // Special case: a parallel-optimizer AllGather primitive gets a cloned primitive
    // with a shifted fusion id, so the duplicated communication op is fused apart
    // from the original ones.
    if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) {
      auto prim = GetValueNode<PrimitivePtr>(input);
      auto instance_name = prim->instance_name();
      bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos;
      int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue<int64_t>(prim->GetAttr(kAttrFusion)) : 0;
      if (is_from_parallel_optimizer && fusion_id > 0) {
        auto new_prim = std::make_shared<Primitive>(prim::kPrimAllGather->name());
        new_prim->SetAttrs(prim->attrs());
        new_prim->set_attr(kAttrFusion, MakeValue(fusion_id + fusion_id_increasement_size));
        new_prim->set_prim_type(prim->prim_type());
        new_prim->set_instance_name(instance_name);
        auto value_node = NewValueNode(new_prim);
        (void)new_inputs.emplace_back(value_node);
        continue;
      }
    }
    MS_EXCEPTION_IF_NULL(input);
    if (!input->isa<CNode>()) {
      // Value nodes and parameters are shared, not duplicated.
      (void)new_inputs.emplace_back(input);
      continue;
    }
    auto input_cnode = input->cast<CNodePtr>();
    if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
      if (IsPrimitiveCNode(input_cnode, prim::kPrimUpdateState)) {
        // Replace the side-effect chain with a fresh U monad; the duplicated node must
        // not be ordered by the original UpdateState.
        auto u = NewValueNode(kUMonad);
        u->set_abstract(kUMonad->ToAbstract());
        (void)new_inputs.emplace_back(u);
      } else {
        (void)new_inputs.emplace_back(input);
      }
    } else {
      // Recomputed input: duplicate it recursively (memoized).
      has_recomputed_inputs = true;
      (void)new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,
                                                      origin_to_recomputed_nodes));
    }
  }
  // Add the execution dependency.
  if (!has_recomputed_inputs && new_inputs.size() > 1) {
    std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
    std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs));
    auto first_input = new_inputs[1];
    MS_EXCEPTION_IF_NULL(first_input);
    // Wrap the first real input in Depend(first_input, MakeTuple(first_target_inputs))
    // to delay execution of the recomputed chain.
    std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_input,
                                          graph->NewCNode(make_tuple_inputs)};
    auto depend_node = graph->NewCNode(depend_inputs);
    MS_EXCEPTION_IF_NULL(depend_node);
    depend_node->set_abstract(first_input->abstract());
    depend_node->AddAttr("recompute_depend", MakeValue(true));
    new_inputs[1] = depend_node;
  }
  auto recomputed_node = CreateNewRecomputedNode(graph, origin_node, new_inputs);
  (void)origin_to_recomputed_nodes->emplace(origin_node, recomputed_node);
  return recomputed_node;
}
// For every target (grad) node, rebuilds it so that each input belonging to the
// recomputed origin set is replaced by its duplicated counterpart, then swaps the new
// node into the graph via the manager.
void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSet<CNodePtr> &target_nodes,
                              const mindspore::HashSet<CNodePtr> &origin_recomputed_nodes,
                              const std::vector<AnfNodePtr> &first_target_inputs,
                              mindspore::HashMap<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &target_node : target_nodes) {
    MS_EXCEPTION_IF_NULL(target_node);
    MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
    auto target_cnode = target_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(target_cnode);
    std::vector<AnfNodePtr> new_target_inputs;
    for (const auto &input : target_cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      if (!input->isa<CNode>()) {
        (void)new_target_inputs.emplace_back(input);
      } else {
        auto input_cnode = input->cast<CNodePtr>();
        if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) {
          // Recomputed input: substitute the duplicated node (built lazily, memoized).
          (void)new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs,
                                                                 origin_recomputed_nodes, origin_to_recomputed_nodes));
        } else {
          (void)new_target_inputs.emplace_back(input_cnode);
        }
      }
    }
    auto new_target_node = graph->NewCNode(new_target_inputs);
    new_target_node->CloneCNodeInfo(target_node);
    new_target_node->AddAttr("target_grad", MakeValue(true));
    new_target_node->set_scope(target_node->scope());
    mng->Replace(target_node, new_target_node);
  }
}
  418. } // namespace
// Entry point of the recompute pass: tags recomputable forward nodes, groups them into
// maximal recomputed sub-graphs, duplicates each sub-graph, and rewires the grad nodes
// to consume the duplicated copies instead of the original forward activations.
void InsertRecomputedNodes(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
  SetRecomputedAttr(graph, origin_nodes_topological);
  // Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
  std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
  mindspore::HashSet<CNodePtr> visited_nodes;
  for (const auto &candidate_recomputed_node : candidate_recomputed_nodes) {
    // Each maximal sub-graph is processed once; skip members already covered.
    if (visited_nodes.find(candidate_recomputed_node) != visited_nodes.end()) {
      continue;
    }
    mindspore::HashSet<CNodePtr> max_recomputed_sub_graph = {candidate_recomputed_node};
    // Get max continuous recomputed sub-graph.
    GetMaxSubGraph(mng, &max_recomputed_sub_graph, true, true);
    visited_nodes.insert(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end());
    // Get the origin recomputed nodes which directly output to the grad nodes.
    mindspore::HashSet<CNodePtr> origin_recomputed_nodes;
    mindspore::HashSet<CNodePtr> target_nodes;
    GetOriginRecomputeAndTargetNodes(mng, max_recomputed_sub_graph, &origin_recomputed_nodes, &target_nodes);
    // Also get the inputs of origin recomputed nodes which eventually output to the grad nodes.
    GetMaxSubGraph(mng, &origin_recomputed_nodes, true, false);
    // Get the inputs of the first target node in the topological sequence. The duplicated recomputed nodes should
    // not be executed until these inputs are ready.
    std::vector<AnfNodePtr> first_target_inputs =
      GetFirstTargetInputs(origin_nodes_topological, origin_recomputed_nodes, target_nodes);
    mindspore::HashMap<CNodePtr, CNodePtr> origin_to_recomputed_nodes;
    // Begin duplicate origin recomputed nodes with each target node.
    DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
                             &origin_to_recomputed_nodes);
  }
  // Set need cse attr for doing cse after recompute.
  for (const auto &node : orders) {
    if (WithRecomputedScope(node)) {
      node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
    }
  }
}
  459. } // namespace opt
  460. } // namespace mindspore