You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functionalize_while.cc 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include <algorithm>
#include <deque>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tools/optimizer/graph/functionalize_while.h"
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"
#include "src/ops/while.h"
  23. namespace {
  24. mindspore::ValueNodePtr GetWhileAnfPrim() {
  25. auto while_primitiveT = new (std::nothrow) mindspore::schema::PrimitiveT;
  26. if (while_primitiveT == nullptr) {
  27. MS_LOG(ERROR) << "new while_primitiveT failed";
  28. return nullptr;
  29. }
  30. while_primitiveT->value.type = mindspore::schema::PrimitiveType_While;
  31. auto whileT = new (std::nothrow) mindspore::schema::WhileT;
  32. whileT->condSubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex();
  33. whileT->bodySubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex();
  34. while_primitiveT->value.value = whileT;
  35. if (while_primitiveT->value.value == nullptr) {
  36. MS_LOG(ERROR) << "new WhileT failed";
  37. delete (while_primitiveT);
  38. return nullptr;
  39. }
  40. auto while_prim = std::make_shared<mindspore::lite::While>(while_primitiveT);
  41. mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_prim);
  42. return partial_anf_prim;
  43. }
  44. } // namespace
  45. namespace mindspore::opt {
  46. using mindspore::lite::RET_NULL_PTR;
  47. CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) {
  48. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch);
  49. }
  50. CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) {
  51. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge);
  52. }
  53. CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) {
  54. return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter);
  55. }
  56. CNodePtr FunctionalizeWhile::BlongToWhichExternalEnter(const CNodePtr &node) {
  57. if (node == nullptr) {
  58. return nullptr;
  59. }
  60. if (FunctionalizeControlOpPass::IsEnter(node)) {
  61. return node;
  62. }
  63. CNodePtr aim_node = nullptr;
  64. std::deque<AnfNodePtr> todo(256);
  65. todo.clear();
  66. for (auto &input_node : node->inputs()) {
  67. if (FunctionalizeControlOpPass::IsEnter(input_node) && WhileNodeExternalInputIsContain(input_node)) {
  68. aim_node = utils::cast<CNodePtr>(input_node);
  69. todo.clear();
  70. break;
  71. }
  72. todo.push_back(input_node);
  73. }
  74. while (!todo.empty()) {
  75. AnfNodePtr todo_node = todo.front();
  76. todo.pop_front();
  77. if (FunctionalizeControlOpPass::IsEnter(todo_node) && WhileNodeExternalInputIsContain(todo_node)) {
  78. aim_node = utils::cast<CNodePtr>(todo_node);
  79. todo.clear();
  80. break;
  81. }
  82. if (utils::isa<CNodePtr>(todo_node)) {
  83. auto cnode = utils::cast<CNodePtr>(todo_node);
  84. for (size_t i = 0; i < cnode->inputs().size(); i++) {
  85. todo.push_back(cnode->input(i));
  86. }
  87. }
  88. }
  89. if (aim_node == nullptr) {
  90. MS_LOG(WARNING) << "not found belonging enter node.";
  91. return nullptr;
  92. }
  93. return aim_node;
  94. }
  95. int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) {
  96. auto index = std::find(input_enter_nodes_.begin(), input_enter_nodes_.end(), node);
  97. if (index == input_enter_nodes_.end()) {
  98. return POS_INVALID;
  99. }
  100. return index - input_enter_nodes_.begin();
  101. }
  102. STATUS FunctionalizeWhile::NewWhileNode() {
  103. ValueNodePtr while_anf_primitive = GetWhileAnfPrim();
  104. if (while_anf_primitive == nullptr) {
  105. MS_LOG(ERROR) << "Get while anf primitive failed.";
  106. return RET_NULL_PTR;
  107. }
  108. static int count = 0;
  109. std::vector<AnfNodePtr> while_op_inputs = {while_anf_primitive};
  110. while_node_ = fg_->NewCNode(while_op_inputs);
  111. while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++));
  112. return RET_OK;
  113. }
  114. STATUS FunctionalizeWhile::IdentifyWhileNodeInput() {
  115. for (auto &node : node_cluster_) {
  116. if (FunctionalizeControlOpPass::IsEnter(node)) {
  117. auto enter_cnode = node->cast<CNodePtr>();
  118. input_enter_nodes_.push_back(enter_cnode);
  119. while_node_->add_input(enter_cnode->input(1));
  120. }
  121. }
  122. if (input_enter_nodes_.empty()) {
  123. MS_LOG(ERROR) << "not found input of while node.";
  124. return RET_ERROR;
  125. }
  126. return RET_OK;
  127. }
  128. STATUS FunctionalizeWhile::IdentifyWhileNodeExternalInput() {
  129. std::deque<AnfNodePtr> todo(128);
  130. std::vector<CNodePtr> merge_nodes{};
  131. todo.clear();
  132. for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) {
  133. todo.push_back(loop_cond_node_->input(i));
  134. }
  135. while (!todo.empty()) {
  136. AnfNodePtr node = todo.front();
  137. todo.pop_front();
  138. if (FunctionalizeControlOpPass::IsMerge(node)) {
  139. merge_nodes.push_back(node->cast<CNodePtr>());
  140. continue;
  141. }
  142. if (utils::isa<CNodePtr>(node)) {
  143. auto cnode = utils::cast<CNodePtr>(node);
  144. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  145. todo.push_back(cnode->input(i));
  146. }
  147. }
  148. }
  149. for (auto &node : merge_nodes) {
  150. external_input_enter_nodes_.push_back(node->input(1)->cast<CNodePtr>());
  151. }
  152. return RET_OK;
  153. }
  154. bool FunctionalizeWhile::WhileNodeExternalInputIsContain(const AnfNodePtr &node) {
  155. auto cnode = node->cast<CNodePtr>();
  156. return std::find(external_input_enter_nodes_.begin(), external_input_enter_nodes_.end(), cnode) !=
  157. external_input_enter_nodes_.end();
  158. }
  159. STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() {
  160. output_exit_nodes_.resize(external_input_enter_nodes_.size());
  161. for (auto &node : node_cluster_) {
  162. // exit ->switch->merge->enter
  163. if (FunctionalizeControlOpPass::IsExit(node)) {
  164. auto exit_node = node->cast<CNodePtr>();
  165. auto switch_node = BlongToWhichSwitch(exit_node);
  166. auto merge_node = BlongToWhichMerge(switch_node);
  167. auto enter_node = BlongToWhichExternalEnter(merge_node);
  168. int pos = PosInInputEnterNodes(enter_node);
  169. if (pos == -1) {
  170. MS_LOG(ERROR) << "not find in input enter nodes.";
  171. return RET_ERROR;
  172. }
  173. output_exit_nodes_.at(pos) = exit_node;
  174. }
  175. }
  176. if (output_exit_nodes_.size() == 1) {
  177. while_node_->set_abstract(output_exit_nodes_[0]->abstract());
  178. } else {
  179. AbstractBasePtrList abstract_list;
  180. abstract_list.resize(output_exit_nodes_.size());
  181. std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(),
  182. [](const CNodePtr &cnode) { return cnode->abstract(); });
  183. while_node_->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  184. }
  185. return RET_OK;
  186. }
  187. STATUS FunctionalizeWhile::UpdateExitNodeUser() {
  188. if (output_exit_nodes_.size() == 1) {
  189. auto manager = fg_->manager();
  190. auto node_users = manager->node_users()[output_exit_nodes_[0]];
  191. for (auto &node_user : node_users) {
  192. if (fg_->nodes().contains(node_user.first)) {
  193. manager->SetEdge(node_user.first, node_user.second, while_node_);
  194. }
  195. }
  196. } else {
  197. for (auto &node : output_exit_nodes_) {
  198. auto manager = fg_->manager();
  199. auto node_users = manager->node_users()[node];
  200. for (auto &node_user : node_users) {
  201. // new getitem
  202. AbstractBasePtrList abstractList;
  203. std::vector<int64_t> shape_vector;
  204. abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
  205. auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim();
  206. if (tuple_get_item_prim_ptr == nullptr) {
  207. MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
  208. return RET_NULL_PTR;
  209. }
  210. auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
  211. const auto &exit_node = node;
  212. auto switch_node = BlongToWhichSwitch(exit_node);
  213. auto merge_node = BlongToWhichMerge(switch_node);
  214. auto enter_node = BlongToWhichExternalEnter(merge_node);
  215. int output_idx = PosInInputEnterNodes(enter_node);
  216. auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
  217. std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
  218. CNodePtr get_item_node = fg_->NewCNode(inputs);
  219. std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
  220. auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
  221. if (abstract == nullptr) {
  222. MS_LOG(ERROR) << "create AbstractTensor failed";
  223. return RET_NULL_PTR;
  224. }
  225. get_item_node->set_abstract(abstract);
  226. get_item_node->set_fullname_with_scope(output_item_name);
  227. // set
  228. if (fg_->nodes().contains(node_user.first)) {
  229. manager->SetEdge(node_user.first, node_user.second, get_item_node);
  230. }
  231. }
  232. }
  233. }
  234. return RET_OK;
  235. }
  236. STATUS FunctionalizeWhile::BuildWhileNode() {
  237. int ret = NewWhileNode();
  238. if (ret != RET_OK) {
  239. MS_LOG(ERROR) << "new while node failed, ret:" << ret;
  240. return ret;
  241. }
  242. ret = IdentifyWhileNodeInput();
  243. if (ret != RET_OK) {
  244. MS_LOG(ERROR) << "identify while node input failed, ret:" << ret;
  245. return ret;
  246. }
  247. ret = IdentifyWhileNodeExternalInput();
  248. if (ret != RET_OK) {
  249. MS_LOG(ERROR) << "identify while node external input failed, ret:" << ret;
  250. return ret;
  251. }
  252. ret = IdentifyWhileNodeOutput();
  253. if (ret != RET_OK) {
  254. MS_LOG(ERROR) << "identify while node output failed, ret:" << ret;
  255. return ret;
  256. }
  257. // update exit node user from exit to while
  258. ret = UpdateExitNodeUser();
  259. if (ret != RET_OK) {
  260. MS_LOG(ERROR) << "update while node users, ret:" << ret;
  261. return ret;
  262. }
  263. return ret;
  264. }
  265. // nodes between loop_cond op and merge op be added into cond_func_graph
  266. STATUS FunctionalizeWhile::CondSubgraphAddNodes() {
  267. std::deque<AnfNodePtr> todo(512);
  268. todo.clear();
  269. for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) {
  270. todo.push_back(loop_cond_node_->input(i));
  271. }
  272. while (!todo.empty()) {
  273. AnfNodePtr node = todo.back();
  274. todo.pop_back();
  275. if (FunctionalizeControlOpPass::IsMerge(node)) {
  276. continue;
  277. }
  278. if (utils::isa<ParameterPtr>(node)) {
  279. cond_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
  280. } else {
  281. cond_sub_func_graph_->AddNode(node);
  282. }
  283. node->set_func_graph(cond_sub_func_graph_);
  284. if (utils::isa<CNodePtr>(node)) {
  285. auto cnode = utils::cast<CNodePtr>(node);
  286. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  287. todo.push_back(cnode->input(i));
  288. }
  289. }
  290. }
  291. return RET_OK;
  292. }
  293. STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
  294. std::vector<AnfNodePtr> nodes_need_drop{};
  295. for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) {
  296. for (auto &input_node : cnode->inputs()) {
  297. if (FunctionalizeControlOpPass::IsMerge(input_node)) {
  298. auto merge_node = input_node->cast<CNodePtr>();
  299. auto enter_node = BlongToWhichEnter(merge_node);
  300. int pos = PosInInputEnterNodes(enter_node);
  301. nodes_need_drop.push_back(cnode);
  302. // set parameter
  303. auto parameter = cond_sub_func_graph_->add_parameter();
  304. parameter->set_abstract(cnode->abstract());
  305. // hardcode for subgraph input name
  306. parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");
  307. // replace merge
  308. auto manager = fg_->manager();
  309. auto node_users = manager->node_users()[cnode];
  310. for (auto &node_user : node_users) {
  311. if (cond_sub_func_graph_->nodes().contains(node_user.first)) {
  312. manager->SetEdge(node_user.first, node_user.second, parameter);
  313. }
  314. }
  315. }
  316. }
  317. }
  318. // drop node from cond_func_graph
  319. for (const auto &node : nodes_need_drop) {
  320. cond_sub_func_graph_->DropNode(node);
  321. }
  322. return RET_OK;
  323. }
  324. STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
  325. auto return_prim_ptr = lite::GetReturnPrim();
  326. if (return_prim_ptr == nullptr) {
  327. MS_LOG(ERROR) << "GetReturnPrim return nullptr";
  328. return RET_NULL_PTR;
  329. }
  330. auto value_node = NewValueNode(return_prim_ptr);
  331. if (value_node == nullptr) {
  332. MS_LOG(ERROR) << "new value_node failed.";
  333. return RET_NULL_PTR;
  334. }
  335. // cond subgraph output is LoopCond's input
  336. std::vector<AnfNodePtr> op_inputs{value_node, loop_cond_node_->input(1)};
  337. auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs);
  338. return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return");
  339. cond_sub_func_graph_->set_return(return_cnode);
  340. // hardcode subgraph outputs name
  341. cond_sub_func_graph_->output()->cast<CNodePtr>()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode");
  342. return RET_OK;
  343. }
  344. STATUS FunctionalizeWhile::BuildCondGraph() {
  345. cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
  346. cond_sub_func_graph_ =
  347. FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  348. if (cond_sub_func_graph_ == nullptr) {
  349. MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
  350. return RET_NULL_PTR;
  351. }
  352. cond_sub_func_graph_->set_manager(fg_->manager());
  353. int ret = CondSubgraphAddNodes();
  354. if (ret != RET_OK) {
  355. MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << ret;
  356. return ret;
  357. }
  358. ret = IdentifyCondSubgraphOutput();
  359. if (ret != RET_OK) {
  360. MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << ret;
  361. return ret;
  362. }
  363. ret = IdentifyCondSubgraphInput();
  364. if (ret != RET_OK) {
  365. MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << ret;
  366. return ret;
  367. }
  368. return ret;
  369. }
  370. // nodes between next_iteration op and switch op will be added into body_func_graph
  371. STATUS FunctionalizeWhile::BodySubgraphAddNodes() {
  372. std::deque<AnfNodePtr> todo(512);
  373. todo.clear();
  374. for (auto &node : node_cluster_) {
  375. if (FunctionalizeControlOpPass::IsNextIteration(node)) {
  376. auto next_iteration_cnode = node->cast<CNodePtr>();
  377. for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) {
  378. todo.push_back(next_iteration_cnode->input(i));
  379. }
  380. body_subgraph_output_map_[node] = next_iteration_cnode->input(1);
  381. }
  382. }
  383. while (!todo.empty()) {
  384. AnfNodePtr node = todo.back();
  385. todo.pop_back();
  386. if (FunctionalizeControlOpPass::IsSwitch(node)) {
  387. continue;
  388. }
  389. if (utils::isa<ParameterPtr>(node)) {
  390. body_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
  391. } else {
  392. body_sub_func_graph_->AddNode(node);
  393. }
  394. node->set_func_graph(body_sub_func_graph_);
  395. if (utils::isa<CNodePtr>(node)) {
  396. auto cnode = utils::cast<CNodePtr>(node);
  397. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  398. todo.push_back(cnode->input(i));
  399. }
  400. }
  401. }
  402. return RET_OK;
  403. }
  404. STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() {
  405. std::vector<AnfNodePtr> nodes_need_drop{};
  406. for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) {
  407. for (auto &input_node : cnode->inputs()) {
  408. if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
  409. auto switch_node = input_node->cast<CNodePtr>();
  410. auto merge_node = BlongToWhichMerge(switch_node);
  411. auto enter_node = BlongToWhichEnter(merge_node);
  412. int pos = PosInInputEnterNodes(enter_node);
  413. if (pos == POS_INVALID) {
  414. continue;
  415. }
  416. nodes_need_drop.push_back(cnode);
  417. // set parameter
  418. auto parameter = body_sub_func_graph_->add_parameter();
  419. parameter->set_abstract(cnode->abstract());
  420. // hardcode for subgraph input name
  421. parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");
  422. // replace switch
  423. auto manager = fg_->manager();
  424. auto node_users = manager->node_users()[cnode];
  425. for (auto &node_user : node_users) {
  426. if (body_sub_func_graph_->nodes().contains(node_user.first)) {
  427. manager->SetEdge(node_user.first, node_user.second, parameter);
  428. }
  429. }
  430. }
  431. }
  432. }
  433. // drop node from cond_func_graph
  434. for (const auto &node : nodes_need_drop) {
  435. body_sub_func_graph_->DropNode(node);
  436. }
  437. return RET_OK;
  438. }
  439. STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
  440. std::vector<AnfNodePtr> tmp_output{};
  441. tmp_output.resize(input_enter_nodes_.size());
  442. for (auto &node_pair : body_subgraph_output_map_) {
  443. auto next_iteration_cnode = utils::cast<CNodePtr>(node_pair.first);
  444. auto switch_node = BlongToWhichSwitch(next_iteration_cnode);
  445. auto merge_node = BlongToWhichMerge(switch_node);
  446. auto enter_node = BlongToWhichEnter(merge_node);
  447. int pos = PosInInputEnterNodes(enter_node);
  448. if (pos == POS_INVALID) {
  449. continue;
  450. }
  451. tmp_output[pos] = node_pair.second;
  452. // hard code. set cnode output name
  453. node_pair.second->cast<CNodePtr>()->set_fullname_with_scope(body_subgraph_name_ + "_output_" + std::to_string(pos) +
  454. "_cnode");
  455. }
  456. auto return_prim_ptr = lite::GetReturnPrim();
  457. if (return_prim_ptr == nullptr) {
  458. MS_LOG(ERROR) << "GetReturnPrim return nullptr";
  459. return RET_NULL_PTR;
  460. }
  461. auto value_node = NewValueNode(return_prim_ptr);
  462. // cond subgraph output is LoopCond's input
  463. std::vector<AnfNodePtr> op_inputs{value_node};
  464. auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs);
  465. return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return");
  466. if (tmp_output.size() == 1) {
  467. return_cnode->add_input(tmp_output[0]);
  468. } else {
  469. std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
  470. auto make_tuple_prim_ptr = lite::GetMakeTuplePrim();
  471. if (make_tuple_prim_ptr == nullptr) {
  472. MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
  473. return RET_NULL_PTR;
  474. }
  475. auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
  476. make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
  477. auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs);
  478. make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple");
  479. return_cnode->add_input(make_tuple_cnode);
  480. }
  481. body_sub_func_graph_->set_return(return_cnode);
  482. return RET_OK;
  483. }
  484. STATUS FunctionalizeWhile::BuildBodyGraph() {
  485. body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
  486. body_sub_func_graph_ =
  487. FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  488. if (body_sub_func_graph_ == nullptr) {
  489. MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
  490. return RET_NULL_PTR;
  491. }
  492. body_sub_func_graph_->set_manager(fg_->manager());
  493. int ret = BodySubgraphAddNodes();
  494. if (ret != RET_OK) {
  495. MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << ret;
  496. return ret;
  497. }
  498. ret = IdentifyBodySubgraphOutput();
  499. if (ret != RET_OK) {
  500. MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << ret;
  501. return ret;
  502. }
  503. ret = IdentifyBodySubgraphInput();
  504. if (ret != RET_OK) {
  505. MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << ret;
  506. return ret;
  507. }
  508. return ret;
  509. }
  510. STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() {
  511. // set while input cond and body vnode
  512. auto cond_value_node = NewValueNode(cond_sub_func_graph_);
  513. auto body_value_node = NewValueNode(body_sub_func_graph_);
  514. auto inputs = while_node_->inputs();
  515. inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
  516. while_node_->set_inputs(inputs);
  517. return RET_OK;
  518. }
  519. STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() {
  520. // fg_ drop cluster node
  521. for (auto &node : node_cluster_) {
  522. fg_->DropNode(node);
  523. }
  524. return RET_OK;
  525. }
  526. STATUS FunctionalizeWhile::Process() {
  527. int ret = BuildWhileNode();
  528. if (ret != RET_OK) {
  529. MS_LOG(ERROR) << "build while node failed, ret:" << ret;
  530. return ret;
  531. }
  532. ret = BuildCondGraph();
  533. if (ret != RET_OK) {
  534. MS_LOG(ERROR) << "build condition graph failed, ret:" << ret;
  535. return ret;
  536. }
  537. ret = BuildBodyGraph();
  538. if (ret != RET_OK) {
  539. MS_LOG(ERROR) << "build body graph failed, ret:" << ret;
  540. return ret;
  541. }
  542. ret = InsertFuncGraphToWhileInput();
  543. if (ret != RET_OK) {
  544. MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret;
  545. return ret;
  546. }
  547. ret = DropUselessNodesInMainGraph();
  548. if (ret != RET_OK) {
  549. MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret;
  550. return ret;
  551. }
  552. return ret;
  553. }
  554. } // namespace mindspore::opt