You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

clean.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "optimizer/clean.h"
  19. #include <map>
  20. #include <string>
  21. #include <vector>
  22. #include <algorithm>
  23. #include <utility>
  24. #include "./common.h"
  25. #include "debug/trace.h"
  26. #include "operator/composite/composite.h"
  27. namespace mindspore {
  28. /* namespace to support opt */
  29. namespace opt {
  30. using mindspore::abstract::AbstractAttribute;
  31. using mindspore::abstract::AbstractClass;
  32. using mindspore::abstract::AbstractDictionary;
  33. using mindspore::abstract::AbstractJTagged;
  34. using mindspore::abstract::AbstractList;
  35. using mindspore::abstract::AbstractScalar;
  36. using mindspore::abstract::AbstractTuple;
  37. static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  38. if (t == nullptr) {
  39. return nullptr;
  40. }
  41. AbstractBasePtr res = t;
  42. if (t->isa<AbstractClass>()) {
  43. auto abs_class = dyn_cast<AbstractClass>(t);
  44. AbstractBasePtrList baselist;
  45. auto attributes = abs_class->attributes();
  46. (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
  47. [](const AbstractAttribute &item) { return item.second; });
  48. res = std::make_shared<AbstractTuple>(baselist);
  49. } else if (t->isa<AbstractDictionary>()) {
  50. auto abs_dict = dyn_cast<AbstractDictionary>(t);
  51. AbstractBasePtrList baselist;
  52. auto elements = abs_dict->elements();
  53. (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
  54. [](const AbstractAttribute &item) { return item.second; });
  55. res = std::make_shared<AbstractTuple>(baselist);
  56. } else if (t->isa<AbstractList>()) {
  57. auto abs_dict = dyn_cast<AbstractList>(t);
  58. res = std::make_shared<AbstractTuple>(abs_dict->elements());
  59. }
  60. return res;
  61. }
  62. AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
  63. MS_EXCEPTION_IF_NULL(node);
  64. MS_EXCEPTION_IF_NULL(node->func_graph());
  65. const auto &inputs = node->inputs();
  66. // Inputs should be [getattr, data, attribute]
  67. MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
  68. AnfNodePtr data = inputs[1];
  69. AnfNodePtr cons = inputs[2];
  70. MS_EXCEPTION_IF_NULL(data);
  71. MS_EXCEPTION_IF_NULL(cons);
  72. auto dt = data->abstract();
  73. if (dt == nullptr) {
  74. return nullptr;
  75. }
  76. if (!dt->isa<AbstractClass>()) {
  77. MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
  78. }
  79. auto cons_is_str = IsValueNode<StringImm>(cons);
  80. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  81. auto ct = dyn_cast<AbstractClass>(dt);
  82. const auto &cmap = ct->attributes();
  83. int count = 0;
  84. for (auto &item : cmap) {
  85. if (cons_is_str && item.first == cons_str) {
  86. break;
  87. }
  88. count++;
  89. }
  90. auto idx_c = NewValueNode(count);
  91. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  92. idx_c->set_abstract(aptr);
  93. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  94. }
  95. AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
  96. MS_EXCEPTION_IF_NULL(node);
  97. MS_EXCEPTION_IF_NULL(node->func_graph());
  98. // Inputs should be [dict_getitem, dict, item]
  99. const auto &inputs = node->inputs();
  100. MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
  101. AnfNodePtr data = inputs[1];
  102. AnfNodePtr cons = inputs[2];
  103. MS_EXCEPTION_IF_NULL(data);
  104. MS_EXCEPTION_IF_NULL(cons);
  105. auto dt = data->abstract();
  106. MS_EXCEPTION_IF_NULL(dt);
  107. if (!dt->isa<abstract::AbstractDictionary>()) {
  108. MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
  109. }
  110. auto cons_is_str = IsValueNode<StringImm>(cons);
  111. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  112. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  113. const auto &cmap = ct->elements();
  114. int count = 0;
  115. for (auto &item : cmap) {
  116. if (cons_is_str && item.first == cons_str) {
  117. break;
  118. }
  119. count++;
  120. }
  121. auto idx_c = NewValueNode(count);
  122. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  123. idx_c->set_abstract(aptr);
  124. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  125. }
  126. AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
  127. MS_EXCEPTION_IF_NULL(node);
  128. MS_EXCEPTION_IF_NULL(node->func_graph());
  129. // Inputs should be [dict_setitem, dict, item, value]
  130. const auto &inputs = node->inputs();
  131. MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs.");
  132. AnfNodePtr data = inputs[1];
  133. AnfNodePtr cons = inputs[2];
  134. AnfNodePtr item_value = inputs[3];
  135. MS_EXCEPTION_IF_NULL(data);
  136. MS_EXCEPTION_IF_NULL(cons);
  137. auto dt = data->abstract();
  138. MS_EXCEPTION_IF_NULL(dt);
  139. if (!dt->isa<abstract::AbstractDictionary>()) {
  140. MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
  141. }
  142. auto cons_is_str = IsValueNode<StringImm>(cons);
  143. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  144. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  145. const auto &cmap = ct->elements();
  146. int count = 0;
  147. for (auto &item : cmap) {
  148. if (cons_is_str && item.first == cons_str) {
  149. break;
  150. }
  151. count++;
  152. }
  153. if (IntToSize(count) >= cmap.size()) {
  154. // for dictionary set, if the key does not exist, we should create a new item
  155. auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
  156. auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
  157. return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
  158. }
  159. auto idx_c = NewValueNode(count);
  160. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  161. idx_c->set_abstract(aptr);
  162. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
  163. }
  164. AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. MS_EXCEPTION_IF_NULL(node->func_graph());
  167. std::vector<AnfNodePtr> inputs;
  168. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  169. // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
  170. (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
  171. return node->func_graph()->NewCNode(inputs);
  172. }
  173. AnfNodePtr ErasePartialNode(const CNodePtr &node) {
  174. MS_EXCEPTION_IF_NULL(node);
  175. MS_EXCEPTION_IF_NULL(node->func_graph());
  176. const auto &inputs = node->inputs();
  177. // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
  178. MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
  179. std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
  180. auto oper = inputs[1];
  181. if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
  182. if (args.size() == 1) {
  183. return NewValueNode(prim::kPrimMakeTuple);
  184. }
  185. if (args.size() > 1) {
  186. std::vector<AnfNodePtr> new_inputs;
  187. new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
  188. new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  189. (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
  190. MS_EXCEPTION_IF_NULL(node->func_graph());
  191. return node->func_graph()->NewCNode(new_inputs);
  192. }
  193. }
  194. return nullptr;
  195. }
  196. AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
  197. MS_EXCEPTION_IF_NULL(node);
  198. MS_EXCEPTION_IF_NULL(node->func_graph());
  199. std::vector<AnfNodePtr> inputs;
  200. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  201. // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
  202. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  203. return node->func_graph()->NewCNode(inputs);
  204. }
  205. AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
  206. MS_EXCEPTION_IF_NULL(node);
  207. MS_EXCEPTION_IF_NULL(node->func_graph());
  208. const auto &inputs = node->inputs();
  209. // Inputs should be [list_getitem, list, item]
  210. if (inputs.size() < 3) {
  211. MS_LOG(EXCEPTION) << "Node's input number < 3.";
  212. }
  213. AnfNodePtr data = inputs[1];
  214. AnfNodePtr cons = inputs[2];
  215. MS_EXCEPTION_IF_NULL(data);
  216. MS_EXCEPTION_IF_NULL(cons);
  217. auto cons_node = cons->cast<ValueNodePtr>();
  218. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
  219. }
  220. AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
  221. MS_EXCEPTION_IF_NULL(node);
  222. MS_EXCEPTION_IF_NULL(node->func_graph());
  223. const auto &inputs = node->inputs();
  224. // Inputs should be [list_setitem, list, index, item]
  225. if (inputs.size() < 4) {
  226. MS_LOG(EXCEPTION) << "Node's input number < 4.";
  227. }
  228. AnfNodePtr data = inputs[1];
  229. AnfNodePtr cons = inputs[2];
  230. AnfNodePtr value = inputs[3];
  231. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
  232. }
  233. AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
  234. MS_EXCEPTION_IF_NULL(node);
  235. const auto &inputs = node->inputs();
  236. MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
  237. return inputs[2];
  238. }
  239. AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
  240. MS_EXCEPTION_IF_NULL(node);
  241. const auto &inputs = node->inputs();
  242. // Inputs should be [make_keyword_arg, key, value]
  243. MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
  244. return inputs[2];
  245. }
  246. AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
  247. MS_EXCEPTION_IF_NULL(node);
  248. const auto &inputs = node->inputs();
  249. // Inputs should be [extract_keyword_arg, arg, key]
  250. MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
  251. return inputs[2];
  252. }
  253. ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
  254. const int DEPTH_MAX = 5;
  255. if (depth > DEPTH_MAX) {
  256. MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
  257. }
  258. std::vector<ValuePtr> elements;
  259. for (const auto &it : value_list->value()) {
  260. ValuePtr value = nullptr;
  261. if (it->isa<ValueList>()) {
  262. value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
  263. } else {
  264. value = it;
  265. }
  266. elements.push_back(value);
  267. }
  268. return std::make_shared<ValueTuple>(elements);
  269. }
  270. AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
  271. MS_EXCEPTION_IF_NULL(node);
  272. ValuePtr value = node->value();
  273. auto value_list = value->cast<ValueListPtr>();
  274. MS_EXCEPTION_IF_NULL(value_list);
  275. int depth = 0;
  276. return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
  277. }
  278. // Convert class to Tuple
  279. // Convert getattr to getitem
  280. // Convert make_record to make_tuple
  281. bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  282. MS_EXCEPTION_IF_NULL(manager);
  283. manager->AddFuncGraph(root);
  284. bool changed = false;
  285. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  286. AnfNodeSet all_node = manager->all_nodes();
  287. for (auto &node : all_node) {
  288. MS_EXCEPTION_IF_NULL(node);
  289. auto cnode = node->cast<CNodePtr>();
  290. AnfNodePtr new_node = nullptr;
  291. if (IsValueNode<parse::ClassObject>(node)) {
  292. new_node = NewValueNode(prim::kPrimMakeTuple);
  293. } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
  294. new_node = ConvertGetAttrToTupleGetItem(cnode);
  295. } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
  296. new_node = ConvertMakeRecordToMakeTuple(cnode);
  297. } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  298. new_node = ErasePartialNode(cnode);
  299. } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
  300. new_node = ConvertDictGetItemToTupleGetItem(cnode);
  301. } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
  302. new_node = ConvertDictSetItemToTupleSetItem(cnode);
  303. } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
  304. new_node = EraseMakeDictNode(cnode);
  305. } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
  306. new_node = EraseMakeKeywordArgNode(cnode);
  307. } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
  308. new_node = EraseExtractKeywordArg(cnode);
  309. } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
  310. new_node = ConvertMakeListToMakeTuple(cnode);
  311. } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
  312. new_node = ConvertListGetItemToTupleGetItem(cnode);
  313. } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
  314. new_node = ConvertListSetItemToTupleSetItem(cnode);
  315. } else if (IsValueNode<ValueList>(node)) {
  316. new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
  317. }
  318. if (new_node != nullptr) {
  319. new_node->set_abstract(node->abstract());
  320. MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
  321. (void)manager->Replace(node, new_node);
  322. changed = true;
  323. }
  324. }
  325. for (auto &node : manager->all_nodes()) {
  326. auto ret = Reabs(node->abstract());
  327. node->set_abstract(ret);
  328. }
  329. return changed;
  330. }
  331. // expand tuples in graph parameters
  332. static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
  333. const std::vector<AnfNodePtr> &params) {
  334. MS_EXCEPTION_IF_NULL(mng);
  335. MS_EXCEPTION_IF_NULL(func_graph);
  336. std::vector<AnfNodePtr> new_params;
  337. for (const auto &param : params) {
  338. MS_EXCEPTION_IF_NULL(param);
  339. auto param_abs = param->abstract();
  340. MS_EXCEPTION_IF_NULL(param_abs);
  341. if (param_abs->isa<AbstractJTagged>()) {
  342. MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
  343. }
  344. if (!param_abs->isa<AbstractTuple>()) {
  345. new_params.emplace_back(param);
  346. continue;
  347. }
  348. std::vector<AnfNodePtr> new_param;
  349. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
  350. auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
  351. for (auto &elem : abs_tuple->elements()) {
  352. auto np = std::make_shared<Parameter>(func_graph);
  353. np->set_abstract(elem);
  354. new_param.emplace_back(np);
  355. }
  356. (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
  357. auto new_tuple = func_graph->NewCNode(inputs);
  358. (void)mng->Replace(param, new_tuple);
  359. auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
  360. (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
  361. }
  362. return new_params;
  363. }
  364. // expand tuples in graph applies
  365. static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
  366. MS_EXCEPTION_IF_NULL(graph);
  367. std::vector<AnfNodePtr> new_inputs;
  368. for (const auto &input : inputs) {
  369. MS_EXCEPTION_IF_NULL(input);
  370. auto input_abs = input->abstract();
  371. MS_EXCEPTION_IF_NULL(input_abs);
  372. if (input_abs->isa<AbstractJTagged>()) {
  373. auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
  374. if (abstract_tag->element()->isa<AbstractTuple>()) {
  375. MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
  376. }
  377. }
  378. if (!input_abs->isa<AbstractTuple>()) {
  379. new_inputs.emplace_back(input);
  380. continue;
  381. }
  382. int idx = 0;
  383. std::vector<AnfNodePtr> new_input;
  384. auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
  385. for (auto &elem : abs_tuple->elements()) {
  386. auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
  387. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
  388. c_node->input(2)->set_abstract(aptr);
  389. c_node->set_abstract(elem);
  390. new_input.emplace_back(c_node);
  391. idx++;
  392. }
  393. auto expand_tuple = ExpandTuplesC(graph, new_input);
  394. (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
  395. }
  396. return new_inputs;
  397. }
  398. // remove most uses of tuples from the graph parameters & apply inputs
  399. // tuples that are returned will be kept
  400. // tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
  401. // CNode("tuple_getitem", (a,b,c), 0)
  402. // CNode("tuple_getitem", (a,b,c), 1)
  403. // CNode("tuple_getitem", (a,b,c), 2)
  404. // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
  405. // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
  406. // cppcheck-suppress unusedFunction
  407. void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  408. MS_EXCEPTION_IF_NULL(manager);
  409. manager->AddFuncGraph(root);
  410. // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  411. AnfNodeSet all_node = manager->all_nodes();
  412. for (auto &node : all_node) {
  413. auto cnode = node->cast<CNodePtr>();
  414. if (cnode == nullptr) {
  415. continue;
  416. }
  417. const auto &inputs = cnode->inputs();
  418. // Bypass the first input in inputs as it's fn.
  419. if (!IsValueNode<Primitive>(inputs[0])) {
  420. std::vector<AnfNodePtr> expand_inputs;
  421. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
  422. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  423. if (new_inputs != expand_inputs) {
  424. std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
  425. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  426. MS_EXCEPTION_IF_NULL(node->func_graph());
  427. auto new_node = node->func_graph()->NewCNode(cnode_inputs);
  428. new_node->set_abstract(node->abstract());
  429. (void)manager->Replace(node, new_node);
  430. }
  431. // Bypass the first 2 inputs in inputs as it's [partial, fn].
  432. } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
  433. std::vector<AnfNodePtr> expand_inputs;
  434. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
  435. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  436. if (new_inputs != expand_inputs) {
  437. std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
  438. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  439. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  440. auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
  441. new_node->set_abstract(cnode->abstract());
  442. (void)manager->Replace(node, new_node);
  443. }
  444. }
  445. }
  446. FuncGraphSet all_graph = manager->func_graphs();
  447. for (auto &func_graph : all_graph) {
  448. MS_EXCEPTION_IF_NULL(func_graph);
  449. auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
  450. manager->SetParameters(func_graph, expand_p);
  451. }
  452. }
  453. } // namespace opt
  454. } // namespace mindspore