You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

functionalize_while.cc 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_set>
#include "tools/optimizer/graph/functionalize_while.h"
#include "include/errorcode.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "tools/converter/ops/while.h"
  25. namespace {
  26. mindspore::ValueNodePtr GetWhileAnfPrim() {
  27. auto while_primc = std::make_shared<mindspore::lite::While>();
  28. if (while_primc == nullptr) {
  29. MS_LOG(ERROR) << "new while_primitive failed";
  30. return nullptr;
  31. }
  32. while_primc->set_cond_subgraph_index(mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex());
  33. while_primc->set_body_subgraph_index(mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex());
  34. mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_primc);
  35. return partial_anf_prim;
  36. }
  37. } // namespace
  38. namespace mindspore::opt {
  39. using mindspore::lite::RET_NULL_PTR;
  40. CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) {
  41. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch);
  42. }
  43. CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) {
  44. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge);
  45. }
  46. CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) {
  47. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter);
  48. }
  49. CNodePtr FunctionalizeWhile::BlongToWhichExternalEnter(const CNodePtr &node) {
  50. if (node == nullptr) {
  51. return nullptr;
  52. }
  53. if (FunctionalizeControlOpPass::IsEnter(node)) {
  54. return node;
  55. }
  56. CNodePtr aim_node = nullptr;
  57. std::deque<AnfNodePtr> todo(256);
  58. todo.clear();
  59. for (auto &input_node : node->inputs()) {
  60. if (FunctionalizeControlOpPass::IsEnter(input_node) && WhileNodeExternalInputIsContain(input_node)) {
  61. aim_node = utils::cast<CNodePtr>(input_node);
  62. todo.clear();
  63. break;
  64. }
  65. todo.push_back(input_node);
  66. }
  67. while (!todo.empty()) {
  68. AnfNodePtr todo_node = todo.front();
  69. todo.pop_front();
  70. if (FunctionalizeControlOpPass::IsEnter(todo_node) && WhileNodeExternalInputIsContain(todo_node)) {
  71. aim_node = utils::cast<CNodePtr>(todo_node);
  72. todo.clear();
  73. break;
  74. }
  75. if (utils::isa<CNodePtr>(todo_node)) {
  76. auto cnode = utils::cast<CNodePtr>(todo_node);
  77. for (size_t i = 0; i < cnode->inputs().size(); i++) {
  78. todo.push_back(cnode->input(i));
  79. }
  80. }
  81. }
  82. if (aim_node == nullptr) {
  83. MS_LOG(WARNING) << "not found belonging enter node.";
  84. return nullptr;
  85. }
  86. return aim_node;
  87. }
  88. int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) {
  89. auto index = std::find(input_enter_nodes_.begin(), input_enter_nodes_.end(), node);
  90. if (index == input_enter_nodes_.end()) {
  91. return POS_INVALID;
  92. }
  93. return index - input_enter_nodes_.begin();
  94. }
  95. STATUS FunctionalizeWhile::NewWhileNode() {
  96. ValueNodePtr while_anf_primitive = GetWhileAnfPrim();
  97. if (while_anf_primitive == nullptr) {
  98. MS_LOG(ERROR) << "Get while anf primitive failed.";
  99. return RET_NULL_PTR;
  100. }
  101. static int count = 0;
  102. std::vector<AnfNodePtr> while_op_inputs = {while_anf_primitive};
  103. while_node_ = fg_->NewCNode(while_op_inputs);
  104. while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++));
  105. return RET_OK;
  106. }
  107. STATUS FunctionalizeWhile::IdentifyWhileNodeInput() {
  108. for (auto &node : node_cluster_) {
  109. if (FunctionalizeControlOpPass::IsEnter(node)) {
  110. auto enter_cnode = node->cast<CNodePtr>();
  111. input_enter_nodes_.push_back(enter_cnode);
  112. while_node_->add_input(enter_cnode->input(1));
  113. }
  114. }
  115. if (input_enter_nodes_.empty()) {
  116. MS_LOG(ERROR) << "not found input of while node.";
  117. return RET_ERROR;
  118. }
  119. return RET_OK;
  120. }
  121. STATUS FunctionalizeWhile::IdentifyWhileNodeExternalInput() {
  122. std::deque<AnfNodePtr> todo(128);
  123. std::vector<CNodePtr> merge_nodes{};
  124. todo.clear();
  125. for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) {
  126. todo.push_back(loop_cond_node_->input(i));
  127. }
  128. while (!todo.empty()) {
  129. AnfNodePtr node = todo.front();
  130. todo.pop_front();
  131. if (FunctionalizeControlOpPass::IsMerge(node)) {
  132. merge_nodes.push_back(node->cast<CNodePtr>());
  133. continue;
  134. }
  135. if (utils::isa<CNodePtr>(node)) {
  136. auto cnode = utils::cast<CNodePtr>(node);
  137. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  138. todo.push_back(cnode->input(i));
  139. }
  140. }
  141. }
  142. for (auto &node : merge_nodes) {
  143. external_input_enter_nodes_.push_back(node->input(1)->cast<CNodePtr>());
  144. }
  145. return RET_OK;
  146. }
  147. bool FunctionalizeWhile::WhileNodeExternalInputIsContain(const AnfNodePtr &node) {
  148. auto cnode = node->cast<CNodePtr>();
  149. return std::find(external_input_enter_nodes_.begin(), external_input_enter_nodes_.end(), cnode) !=
  150. external_input_enter_nodes_.end();
  151. }
  152. STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() {
  153. output_exit_nodes_.resize(external_input_enter_nodes_.size());
  154. for (auto &node : node_cluster_) {
  155. // exit ->switch->merge->enter
  156. if (FunctionalizeControlOpPass::IsExit(node)) {
  157. auto exit_node = node->cast<CNodePtr>();
  158. auto switch_node = BlongToWhichSwitch(exit_node);
  159. auto merge_node = BlongToWhichMerge(switch_node);
  160. auto enter_node = BlongToWhichExternalEnter(merge_node);
  161. int pos = PosInInputEnterNodes(enter_node);
  162. if (pos == -1) {
  163. MS_LOG(ERROR) << "not find in input enter nodes.";
  164. return RET_ERROR;
  165. }
  166. output_exit_nodes_.at(pos) = exit_node;
  167. }
  168. }
  169. if (output_exit_nodes_.size() == 1) {
  170. while_node_->set_abstract(output_exit_nodes_[0]->abstract());
  171. } else {
  172. AbstractBasePtrList abstract_list;
  173. abstract_list.resize(output_exit_nodes_.size());
  174. std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(),
  175. [](const CNodePtr &cnode) { return cnode->abstract(); });
  176. while_node_->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  177. }
  178. return RET_OK;
  179. }
  180. STATUS FunctionalizeWhile::UpdateExitNodeUser() {
  181. auto manager = fg_->manager();
  182. if (output_exit_nodes_.size() == 1) {
  183. if (!manager->Replace(output_exit_nodes_[0], while_node_)) {
  184. MS_LOG(ERROR) << "replace node failed.";
  185. return RET_ERROR;
  186. }
  187. } else {
  188. for (auto &node : output_exit_nodes_) {
  189. auto node_users = manager->node_users()[node];
  190. for (auto &node_user : node_users) {
  191. // new getitem
  192. AbstractBasePtrList abstractList;
  193. std::vector<int64_t> shape_vector;
  194. abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
  195. auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
  196. if (tuple_get_item_prim_ptr == nullptr) {
  197. MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
  198. return RET_NULL_PTR;
  199. }
  200. auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
  201. const auto &exit_node = node;
  202. auto switch_node = BlongToWhichSwitch(exit_node);
  203. auto merge_node = BlongToWhichMerge(switch_node);
  204. auto enter_node = BlongToWhichExternalEnter(merge_node);
  205. int output_idx = PosInInputEnterNodes(enter_node);
  206. auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
  207. std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
  208. CNodePtr get_item_node = fg_->NewCNode(inputs);
  209. std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
  210. auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
  211. if (abstract == nullptr) {
  212. MS_LOG(ERROR) << "create AbstractTensor failed";
  213. return RET_NULL_PTR;
  214. }
  215. get_item_node->set_abstract(abstract);
  216. get_item_node->set_fullname_with_scope(output_item_name);
  217. // set
  218. if (fg_->nodes().contains(node_user.first)) {
  219. if (!manager->Replace(output_exit_nodes_[0], while_node_)) {
  220. MS_LOG(ERROR) << "replace node failed.";
  221. return RET_ERROR;
  222. }
  223. }
  224. }
  225. }
  226. }
  227. return RET_OK;
  228. }
  229. STATUS FunctionalizeWhile::BuildWhileNode() {
  230. int ret = NewWhileNode();
  231. if (ret != RET_OK) {
  232. MS_LOG(ERROR) << "new while node failed, ret:" << ret;
  233. return ret;
  234. }
  235. ret = IdentifyWhileNodeInput();
  236. if (ret != RET_OK) {
  237. MS_LOG(ERROR) << "identify while node input failed, ret:" << ret;
  238. return ret;
  239. }
  240. ret = IdentifyWhileNodeExternalInput();
  241. if (ret != RET_OK) {
  242. MS_LOG(ERROR) << "identify while node external input failed, ret:" << ret;
  243. return ret;
  244. }
  245. ret = IdentifyWhileNodeOutput();
  246. if (ret != RET_OK) {
  247. MS_LOG(ERROR) << "identify while node output failed, ret:" << ret;
  248. return ret;
  249. }
  250. // update exit node user from exit to while
  251. ret = UpdateExitNodeUser();
  252. if (ret != RET_OK) {
  253. MS_LOG(ERROR) << "update while node users, ret:" << ret;
  254. return ret;
  255. }
  256. return ret;
  257. }
  258. // nodes between loop_cond op and merge op be added into cond_func_graph
  259. STATUS FunctionalizeWhile::CondSubgraphAddNodes() {
  260. std::deque<AnfNodePtr> todo(512);
  261. todo.clear();
  262. for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) {
  263. todo.push_back(loop_cond_node_->input(i));
  264. }
  265. while (!todo.empty()) {
  266. AnfNodePtr node = todo.back();
  267. todo.pop_back();
  268. if (FunctionalizeControlOpPass::IsMerge(node)) {
  269. continue;
  270. }
  271. if (utils::isa<ParameterPtr>(node)) {
  272. cond_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
  273. } else {
  274. cond_sub_func_graph_->AddNode(node);
  275. }
  276. node->set_func_graph(cond_sub_func_graph_);
  277. if (utils::isa<CNodePtr>(node)) {
  278. auto cnode = utils::cast<CNodePtr>(node);
  279. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  280. todo.push_back(cnode->input(i));
  281. }
  282. }
  283. }
  284. return RET_OK;
  285. }
  286. STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
  287. std::vector<AnfNodePtr> nodes_need_drop{};
  288. for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) {
  289. for (auto &input_node : cnode->inputs()) {
  290. if (FunctionalizeControlOpPass::IsMerge(input_node)) {
  291. auto merge_node = input_node->cast<CNodePtr>();
  292. auto enter_node = BlongToWhichEnter(merge_node);
  293. int pos = PosInInputEnterNodes(enter_node);
  294. nodes_need_drop.push_back(cnode);
  295. // set parameter
  296. auto parameter = cond_sub_func_graph_->add_parameter();
  297. parameter->set_abstract(cnode->abstract());
  298. // hardcode for subgraph input name
  299. parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");
  300. // replace merge
  301. auto manager = fg_->manager();
  302. auto node_users = manager->node_users()[cnode];
  303. for (auto &node_user : node_users) {
  304. if (cond_sub_func_graph_->nodes().contains(node_user.first)) {
  305. manager->SetEdge(node_user.first, node_user.second, parameter);
  306. }
  307. }
  308. }
  309. }
  310. }
  311. // drop node from cond_func_graph
  312. for (const auto &node : nodes_need_drop) {
  313. cond_sub_func_graph_->DropNode(node);
  314. }
  315. return RET_OK;
  316. }
  317. STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
  318. auto return_prim_ptr = std::make_shared<ops::Return>();
  319. if (return_prim_ptr == nullptr) {
  320. MS_LOG(ERROR) << "GetReturnPrim return nullptr";
  321. return RET_NULL_PTR;
  322. }
  323. auto value_node = NewValueNode(return_prim_ptr);
  324. if (value_node == nullptr) {
  325. MS_LOG(ERROR) << "new value_node failed.";
  326. return RET_NULL_PTR;
  327. }
  328. // cond subgraph output is LoopCond's input
  329. std::vector<AnfNodePtr> op_inputs{value_node, loop_cond_node_->input(1)};
  330. auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs);
  331. return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return");
  332. cond_sub_func_graph_->set_return(return_cnode);
  333. // hardcode subgraph outputs name
  334. cond_sub_func_graph_->output()->cast<CNodePtr>()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode");
  335. return RET_OK;
  336. }
  337. STATUS FunctionalizeWhile::BuildCondGraph() {
  338. cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
  339. cond_sub_func_graph_ =
  340. FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  341. if (cond_sub_func_graph_ == nullptr) {
  342. MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
  343. return RET_NULL_PTR;
  344. }
  345. cond_sub_func_graph_->set_manager(fg_->manager());
  346. int ret = CondSubgraphAddNodes();
  347. if (ret != RET_OK) {
  348. MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << ret;
  349. return ret;
  350. }
  351. ret = IdentifyCondSubgraphOutput();
  352. if (ret != RET_OK) {
  353. MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << ret;
  354. return ret;
  355. }
  356. ret = IdentifyCondSubgraphInput();
  357. if (ret != RET_OK) {
  358. MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << ret;
  359. return ret;
  360. }
  361. return ret;
  362. }
  363. // nodes between next_iteration op and switch op will be added into body_func_graph
  364. STATUS FunctionalizeWhile::BodySubgraphAddNodes() {
  365. std::deque<AnfNodePtr> todo(512);
  366. todo.clear();
  367. for (auto &node : node_cluster_) {
  368. if (FunctionalizeControlOpPass::IsNextIteration(node)) {
  369. auto next_iteration_cnode = node->cast<CNodePtr>();
  370. for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) {
  371. todo.push_back(next_iteration_cnode->input(i));
  372. }
  373. body_subgraph_output_map_[node] = next_iteration_cnode->input(1);
  374. }
  375. }
  376. while (!todo.empty()) {
  377. AnfNodePtr node = todo.back();
  378. todo.pop_back();
  379. if (FunctionalizeControlOpPass::IsSwitch(node)) {
  380. continue;
  381. }
  382. if (utils::isa<ParameterPtr>(node)) {
  383. body_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
  384. } else {
  385. body_sub_func_graph_->AddNode(node);
  386. }
  387. node->set_func_graph(body_sub_func_graph_);
  388. if (utils::isa<CNodePtr>(node)) {
  389. auto cnode = utils::cast<CNodePtr>(node);
  390. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  391. todo.push_back(cnode->input(i));
  392. }
  393. }
  394. }
  395. return RET_OK;
  396. }
  397. STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() {
  398. std::vector<AnfNodePtr> nodes_need_drop{};
  399. for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) {
  400. for (auto &input_node : cnode->inputs()) {
  401. if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
  402. auto switch_node = input_node->cast<CNodePtr>();
  403. auto merge_node = BlongToWhichMerge(switch_node);
  404. auto enter_node = BlongToWhichEnter(merge_node);
  405. int pos = PosInInputEnterNodes(enter_node);
  406. if (pos == POS_INVALID) {
  407. continue;
  408. }
  409. nodes_need_drop.push_back(cnode);
  410. // set parameter
  411. auto parameter = body_sub_func_graph_->add_parameter();
  412. parameter->set_abstract(cnode->abstract());
  413. // hardcode for subgraph input name
  414. parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");
  415. // replace switch
  416. auto manager = fg_->manager();
  417. auto node_users = manager->node_users()[cnode];
  418. for (auto &node_user : node_users) {
  419. if (body_sub_func_graph_->nodes().contains(node_user.first)) {
  420. manager->SetEdge(node_user.first, node_user.second, parameter);
  421. }
  422. }
  423. }
  424. }
  425. }
  426. // drop node from cond_func_graph
  427. for (const auto &node : nodes_need_drop) {
  428. body_sub_func_graph_->DropNode(node);
  429. }
  430. return RET_OK;
  431. }
  432. STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
  433. std::vector<AnfNodePtr> tmp_output{};
  434. tmp_output.resize(input_enter_nodes_.size());
  435. for (auto &node_pair : body_subgraph_output_map_) {
  436. auto next_iteration_cnode = utils::cast<CNodePtr>(node_pair.first);
  437. auto switch_node = BlongToWhichSwitch(next_iteration_cnode);
  438. auto merge_node = BlongToWhichMerge(switch_node);
  439. auto enter_node = BlongToWhichEnter(merge_node);
  440. int pos = PosInInputEnterNodes(enter_node);
  441. if (pos == POS_INVALID) {
  442. continue;
  443. }
  444. tmp_output[pos] = node_pair.second;
  445. // hard code. set cnode output name
  446. node_pair.second->cast<CNodePtr>()->set_fullname_with_scope(body_subgraph_name_ + "_output_" + std::to_string(pos) +
  447. "_cnode");
  448. }
  449. auto return_prim_ptr = std::make_shared<ops::Return>();
  450. if (return_prim_ptr == nullptr) {
  451. MS_LOG(ERROR) << "GetReturnPrim return nullptr";
  452. return RET_NULL_PTR;
  453. }
  454. auto value_node = NewValueNode(return_prim_ptr);
  455. // cond subgraph output is LoopCond's input
  456. std::vector<AnfNodePtr> op_inputs{value_node};
  457. auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs);
  458. return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return");
  459. if (tmp_output.size() == 1) {
  460. return_cnode->add_input(tmp_output[0]);
  461. } else {
  462. std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
  463. auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
  464. if (make_tuple_prim_ptr == nullptr) {
  465. MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
  466. return RET_NULL_PTR;
  467. }
  468. auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
  469. make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
  470. auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs);
  471. make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple");
  472. return_cnode->add_input(make_tuple_cnode);
  473. }
  474. body_sub_func_graph_->set_return(return_cnode);
  475. return RET_OK;
  476. }
  477. STATUS FunctionalizeWhile::BuildBodyGraph() {
  478. body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
  479. body_sub_func_graph_ =
  480. FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  481. if (body_sub_func_graph_ == nullptr) {
  482. MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
  483. return RET_NULL_PTR;
  484. }
  485. body_sub_func_graph_->set_manager(fg_->manager());
  486. int ret = BodySubgraphAddNodes();
  487. if (ret != RET_OK) {
  488. MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << ret;
  489. return ret;
  490. }
  491. ret = IdentifyBodySubgraphOutput();
  492. if (ret != RET_OK) {
  493. MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << ret;
  494. return ret;
  495. }
  496. ret = IdentifyBodySubgraphInput();
  497. if (ret != RET_OK) {
  498. MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << ret;
  499. return ret;
  500. }
  501. return ret;
  502. }
  503. STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() {
  504. // set while input cond and body vnode
  505. auto cond_value_node = NewValueNode(cond_sub_func_graph_);
  506. auto body_value_node = NewValueNode(body_sub_func_graph_);
  507. auto inputs = while_node_->inputs();
  508. inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
  509. while_node_->set_inputs(inputs);
  510. return RET_OK;
  511. }
  512. STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() {
  513. // fg_ drop cluster node
  514. for (auto &node : node_cluster_) {
  515. fg_->DropNode(node);
  516. }
  517. return RET_OK;
  518. }
  519. STATUS FunctionalizeWhile::Process() {
  520. int ret = BuildWhileNode();
  521. if (ret != RET_OK) {
  522. MS_LOG(ERROR) << "build while node failed, ret:" << ret;
  523. return ret;
  524. }
  525. ret = BuildCondGraph();
  526. if (ret != RET_OK) {
  527. MS_LOG(ERROR) << "build condition graph failed, ret:" << ret;
  528. return ret;
  529. }
  530. ret = BuildBodyGraph();
  531. if (ret != RET_OK) {
  532. MS_LOG(ERROR) << "build body graph failed, ret:" << ret;
  533. return ret;
  534. }
  535. ret = InsertFuncGraphToWhileInput();
  536. if (ret != RET_OK) {
  537. MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret;
  538. return ret;
  539. }
  540. ret = DropUselessNodesInMainGraph();
  541. if (ret != RET_OK) {
  542. MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret;
  543. return ret;
  544. }
  545. return ret;
  546. }
  547. } // namespace mindspore::opt