You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clean.cc 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/optimizer/clean.h"
  19. #include <string>
  20. #include <vector>
  21. #include <algorithm>
  22. #include "debug/trace.h"
  23. #include "frontend/operator/composite/composite.h"
  24. #include "pipeline/jit/parse/resolve.h"
  25. namespace mindspore {
  26. /* namespace to support opt */
  27. namespace opt {
  28. using mindspore::abstract::AbstractAttribute;
  29. using mindspore::abstract::AbstractClass;
  30. using mindspore::abstract::AbstractDictionary;
  31. using mindspore::abstract::AbstractJTagged;
  32. using mindspore::abstract::AbstractList;
  33. using mindspore::abstract::AbstractRowTensor;
  34. using mindspore::abstract::AbstractScalar;
  35. using mindspore::abstract::AbstractSparseTensor;
  36. using mindspore::abstract::AbstractTuple;
  37. using mindspore::abstract::AbstractUndetermined;
  38. static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  39. if (t == nullptr) {
  40. return nullptr;
  41. }
  42. if (t->isa<AbstractClass>()) {
  43. auto abs_class = dyn_cast<AbstractClass>(t);
  44. AbstractBasePtrList baselist;
  45. auto attributes = abs_class->attributes();
  46. (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
  47. [](const AbstractAttribute &item) { return item.second; });
  48. return std::make_shared<AbstractTuple>(baselist);
  49. }
  50. if (t->isa<AbstractDictionary>()) {
  51. auto abs_dict = dyn_cast<AbstractDictionary>(t);
  52. AbstractBasePtrList baselist;
  53. auto elements = abs_dict->elements();
  54. (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
  55. [](const AbstractAttribute &item) { return item.second; });
  56. return std::make_shared<AbstractTuple>(baselist);
  57. }
  58. return nullptr;
  59. }
  60. static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  61. if (t == nullptr) {
  62. return nullptr;
  63. }
  64. if (t->isa<AbstractList>()) {
  65. auto abs_list = dyn_cast<AbstractList>(t);
  66. return std::make_shared<AbstractTuple>(abs_list->elements());
  67. }
  68. if (t->isa<AbstractSparseTensor>()) {
  69. auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
  70. std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
  71. return std::make_shared<AbstractTuple>(abstract_list);
  72. }
  73. if (t->isa<AbstractRowTensor>()) {
  74. auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
  75. std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
  76. abs_row_tensor->dense_shape()};
  77. return std::make_shared<AbstractTuple>(abstract_list);
  78. }
  79. return nullptr;
  80. }
  81. AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
  82. MS_EXCEPTION_IF_NULL(node);
  83. MS_EXCEPTION_IF_NULL(node->func_graph());
  84. const auto &inputs = node->inputs();
  85. // Inputs should be [getattr, data, attribute]
  86. MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
  87. AnfNodePtr data = inputs[1];
  88. AnfNodePtr cons = inputs[2];
  89. MS_EXCEPTION_IF_NULL(data);
  90. MS_EXCEPTION_IF_NULL(cons);
  91. auto dt = data->abstract();
  92. if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
  93. return nullptr;
  94. }
  95. if (!dt->isa<AbstractClass>()) {
  96. MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
  97. }
  98. auto cons_is_str = IsValueNode<StringImm>(cons);
  99. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  100. auto ct = dyn_cast<AbstractClass>(dt);
  101. const auto &cmap = ct->attributes();
  102. int count = 0;
  103. for (auto &item : cmap) {
  104. if (cons_is_str && item.first == cons_str) {
  105. break;
  106. }
  107. count++;
  108. }
  109. auto idx_c = NewValueNode(count);
  110. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  111. idx_c->set_abstract(aptr);
  112. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  113. }
  114. AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
  115. MS_EXCEPTION_IF_NULL(node);
  116. MS_EXCEPTION_IF_NULL(node->func_graph());
  117. // Inputs should be [dict_getitem, dict, item]
  118. const auto &inputs = node->inputs();
  119. MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
  120. AnfNodePtr data = inputs[1];
  121. AnfNodePtr cons = inputs[2];
  122. MS_EXCEPTION_IF_NULL(data);
  123. MS_EXCEPTION_IF_NULL(cons);
  124. auto dt = data->abstract();
  125. MS_EXCEPTION_IF_NULL(dt);
  126. if (!dt->isa<abstract::AbstractDictionary>()) {
  127. MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
  128. }
  129. auto cons_is_str = IsValueNode<StringImm>(cons);
  130. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  131. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  132. const auto &cmap = ct->elements();
  133. int count = 0;
  134. for (auto &item : cmap) {
  135. if (cons_is_str && item.first == cons_str) {
  136. break;
  137. }
  138. count++;
  139. }
  140. auto idx_c = NewValueNode(count);
  141. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  142. idx_c->set_abstract(aptr);
  143. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  144. }
  145. AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
  146. MS_EXCEPTION_IF_NULL(node);
  147. MS_EXCEPTION_IF_NULL(node->func_graph());
  148. // Inputs should be [dict_setitem, dict, item, value]
  149. const auto &inputs = node->inputs();
  150. MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs.");
  151. AnfNodePtr data = inputs[1];
  152. AnfNodePtr cons = inputs[2];
  153. AnfNodePtr item_value = inputs[3];
  154. MS_EXCEPTION_IF_NULL(data);
  155. MS_EXCEPTION_IF_NULL(cons);
  156. auto dt = data->abstract();
  157. MS_EXCEPTION_IF_NULL(dt);
  158. if (!dt->isa<abstract::AbstractDictionary>()) {
  159. MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
  160. }
  161. auto cons_is_str = IsValueNode<StringImm>(cons);
  162. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  163. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  164. const auto &cmap = ct->elements();
  165. int count = 0;
  166. for (auto &item : cmap) {
  167. if (cons_is_str && item.first == cons_str) {
  168. break;
  169. }
  170. count++;
  171. }
  172. if (IntToSize(count) >= cmap.size()) {
  173. // for dictionary set, if the key does not exist, we should create a new item
  174. auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
  175. auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
  176. return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
  177. }
  178. auto idx_c = NewValueNode(count);
  179. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  180. idx_c->set_abstract(aptr);
  181. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
  182. }
  183. AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
  184. MS_EXCEPTION_IF_NULL(node);
  185. MS_EXCEPTION_IF_NULL(node->func_graph());
  186. std::vector<AnfNodePtr> inputs;
  187. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  188. // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
  189. (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
  190. return node->func_graph()->NewCNode(inputs);
  191. }
  192. AnfNodePtr ErasePartialNode(const CNodePtr &node) {
  193. MS_EXCEPTION_IF_NULL(node);
  194. MS_EXCEPTION_IF_NULL(node->func_graph());
  195. const auto &inputs = node->inputs();
  196. // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
  197. MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
  198. std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
  199. auto oper = inputs[1];
  200. if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
  201. if (args.size() == 1) {
  202. return NewValueNode(prim::kPrimMakeTuple);
  203. }
  204. if (args.size() > 1) {
  205. std::vector<AnfNodePtr> new_inputs;
  206. new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
  207. new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  208. (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
  209. MS_EXCEPTION_IF_NULL(node->func_graph());
  210. return node->func_graph()->NewCNode(new_inputs);
  211. }
  212. }
  213. return nullptr;
  214. }
  215. AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
  216. MS_EXCEPTION_IF_NULL(node);
  217. MS_EXCEPTION_IF_NULL(node->func_graph());
  218. std::vector<AnfNodePtr> inputs;
  219. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  220. // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
  221. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  222. return node->func_graph()->NewCNode(inputs);
  223. }
  224. AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
  225. MS_EXCEPTION_IF_NULL(node);
  226. MS_EXCEPTION_IF_NULL(node->func_graph());
  227. const auto &inputs = node->inputs();
  228. // Inputs should be [list_getitem, list, item]
  229. if (inputs.size() < 3) {
  230. MS_LOG(EXCEPTION) << "Node's input number < 3.";
  231. }
  232. AnfNodePtr data = inputs[1];
  233. AnfNodePtr cons = inputs[2];
  234. MS_EXCEPTION_IF_NULL(data);
  235. MS_EXCEPTION_IF_NULL(cons);
  236. auto cons_node = cons->cast<ValueNodePtr>();
  237. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
  238. }
  239. AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
  240. MS_EXCEPTION_IF_NULL(node);
  241. MS_EXCEPTION_IF_NULL(node->func_graph());
  242. const auto &inputs = node->inputs();
  243. // Inputs should be [list_setitem, list, index, item]
  244. if (inputs.size() < 4) {
  245. MS_LOG(EXCEPTION) << "Node's input number < 4.";
  246. }
  247. AnfNodePtr data = inputs[1];
  248. AnfNodePtr cons = inputs[2];
  249. AnfNodePtr value = inputs[3];
  250. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
  251. }
  252. AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
  253. MS_EXCEPTION_IF_NULL(node);
  254. const auto &inputs = node->inputs();
  255. MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
  256. return inputs[2];
  257. }
  258. AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
  259. MS_EXCEPTION_IF_NULL(node);
  260. const auto &inputs = node->inputs();
  261. // Inputs should be [make_keyword_arg, key, value]
  262. MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
  263. return inputs[2];
  264. }
  265. AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
  266. MS_EXCEPTION_IF_NULL(node);
  267. const auto &inputs = node->inputs();
  268. // Inputs should be [extract_keyword_arg, arg, key]
  269. MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
  270. return inputs[2];
  271. }
  272. ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
  273. const int DEPTH_MAX = 5;
  274. if (depth > DEPTH_MAX) {
  275. MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
  276. }
  277. std::vector<ValuePtr> elements;
  278. for (const auto &it : value_list->value()) {
  279. ValuePtr value = nullptr;
  280. if (it->isa<ValueList>()) {
  281. value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
  282. } else {
  283. value = it;
  284. }
  285. elements.push_back(value);
  286. }
  287. return std::make_shared<ValueTuple>(elements);
  288. }
  289. AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
  290. MS_EXCEPTION_IF_NULL(node);
  291. ValuePtr value = node->value();
  292. auto value_list = value->cast<ValueListPtr>();
  293. MS_EXCEPTION_IF_NULL(value_list);
  294. int depth = 0;
  295. return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
  296. }
  297. // Convert class to Tuple
  298. // Convert getattr to getitem
  299. // Convert make_record to make_tuple
  300. bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  301. MS_EXCEPTION_IF_NULL(manager);
  302. manager->AddFuncGraph(root);
  303. bool changed = false;
  304. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  305. AnfNodeSet all_node = manager->all_nodes();
  306. for (auto &node : all_node) {
  307. MS_EXCEPTION_IF_NULL(node);
  308. auto cnode = node->cast<CNodePtr>();
  309. AnfNodePtr new_node = nullptr;
  310. if (IsValueNode<parse::ClassObject>(node)) {
  311. new_node = NewValueNode(prim::kPrimMakeTuple);
  312. } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
  313. new_node = ConvertGetAttrToTupleGetItem(cnode);
  314. } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
  315. new_node = ConvertMakeRecordToMakeTuple(cnode);
  316. } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  317. new_node = ErasePartialNode(cnode);
  318. } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
  319. new_node = ConvertDictGetItemToTupleGetItem(cnode);
  320. } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
  321. new_node = ConvertDictSetItemToTupleSetItem(cnode);
  322. } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
  323. new_node = EraseMakeDictNode(cnode);
  324. } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
  325. new_node = EraseMakeKeywordArgNode(cnode);
  326. } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
  327. new_node = EraseExtractKeywordArg(cnode);
  328. }
  329. if (new_node != nullptr) {
  330. new_node->set_abstract(node->abstract());
  331. MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
  332. (void)manager->Replace(node, new_node);
  333. changed = true;
  334. }
  335. }
  336. for (auto &node : manager->all_nodes()) {
  337. auto ret = Reabs(node->abstract());
  338. if (ret) {
  339. MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
  340. << ret->ToString();
  341. node->set_abstract(ret);
  342. if (ret->cast<abstract::AbstractTuplePtr>()->size() > 0) {
  343. changed = true;
  344. }
  345. }
  346. }
  347. return changed;
  348. }
  349. AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
  350. MS_EXCEPTION_IF_NULL(node);
  351. MS_EXCEPTION_IF_NULL(node->func_graph());
  352. std::vector<AnfNodePtr> inputs;
  353. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  354. // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
  355. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  356. return node->func_graph()->NewCNode(inputs);
  357. }
  358. AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int &index) {
  359. MS_EXCEPTION_IF_NULL(node);
  360. MS_EXCEPTION_IF_NULL(node->func_graph());
  361. const auto &inputs = node->inputs();
  362. // Inputs should be [spase_getattr, sparse]
  363. if (inputs.size() < 2) {
  364. MS_LOG(EXCEPTION) << "Node's input number < 2.";
  365. }
  366. AnfNodePtr sparse = inputs[1];
  367. MS_EXCEPTION_IF_NULL(sparse);
  368. auto cons_node = NewValueNode(index);
  369. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(index));
  370. cons_node->set_abstract(aptr);
  371. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
  372. }
  373. bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  374. MS_EXCEPTION_IF_NULL(manager);
  375. manager->AddFuncGraph(root);
  376. bool changed = false;
  377. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  378. auto all_node = manager->all_nodes();
  379. for (auto &node : all_node) {
  380. MS_EXCEPTION_IF_NULL(node);
  381. auto cnode = node->cast<CNodePtr>();
  382. AnfNodePtr new_node = nullptr;
  383. if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
  384. new_node = ConvertMakeListToMakeTuple(cnode);
  385. } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
  386. new_node = ConvertListGetItemToTupleGetItem(cnode);
  387. } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
  388. new_node = ConvertListSetItemToTupleSetItem(cnode);
  389. } else if (IsValueNode<ValueList>(node)) {
  390. new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
  391. } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
  392. IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
  393. new_node = ConvertMakeSparseToMakeTuple(cnode);
  394. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
  395. IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
  396. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
  397. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
  398. IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
  399. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
  400. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
  401. IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
  402. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
  403. }
  404. if (new_node != nullptr) {
  405. new_node->set_abstract(node->abstract());
  406. MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
  407. (void)manager->Replace(node, new_node);
  408. changed = true;
  409. }
  410. }
  411. for (auto &node : manager->all_nodes()) {
  412. auto ret = AdaptAbs(node->abstract());
  413. if (ret) {
  414. MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
  415. << ret->ToString();
  416. node->set_abstract(ret);
  417. changed = true;
  418. }
  419. }
  420. return changed;
  421. }
  422. // expand tuples in graph parameters
  423. static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
  424. const std::vector<AnfNodePtr> &params) {
  425. MS_EXCEPTION_IF_NULL(mng);
  426. MS_EXCEPTION_IF_NULL(func_graph);
  427. std::vector<AnfNodePtr> new_params;
  428. for (const auto &param : params) {
  429. MS_EXCEPTION_IF_NULL(param);
  430. auto param_abs = param->abstract();
  431. MS_EXCEPTION_IF_NULL(param_abs);
  432. if (param_abs->isa<AbstractJTagged>()) {
  433. MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
  434. }
  435. if (!param_abs->isa<AbstractTuple>()) {
  436. new_params.emplace_back(param);
  437. continue;
  438. }
  439. std::vector<AnfNodePtr> new_param;
  440. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
  441. auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
  442. for (auto &elem : abs_tuple->elements()) {
  443. auto np = std::make_shared<Parameter>(func_graph);
  444. np->set_abstract(elem);
  445. new_param.emplace_back(np);
  446. }
  447. (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
  448. auto new_tuple = func_graph->NewCNode(inputs);
  449. (void)mng->Replace(param, new_tuple);
  450. auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
  451. (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
  452. }
  453. return new_params;
  454. }
  455. // expand tuples in graph applies
  456. static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
  457. MS_EXCEPTION_IF_NULL(graph);
  458. std::vector<AnfNodePtr> new_inputs;
  459. for (const auto &input : inputs) {
  460. MS_EXCEPTION_IF_NULL(input);
  461. auto input_abs = input->abstract();
  462. MS_EXCEPTION_IF_NULL(input_abs);
  463. if (input_abs->isa<AbstractJTagged>()) {
  464. auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
  465. if (abstract_tag->element()->isa<AbstractTuple>()) {
  466. MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
  467. }
  468. }
  469. if (!input_abs->isa<AbstractTuple>()) {
  470. new_inputs.emplace_back(input);
  471. continue;
  472. }
  473. int idx = 0;
  474. std::vector<AnfNodePtr> new_input;
  475. auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
  476. for (auto &elem : abs_tuple->elements()) {
  477. auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
  478. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
  479. c_node->input(2)->set_abstract(aptr);
  480. c_node->set_abstract(elem);
  481. new_input.emplace_back(c_node);
  482. idx++;
  483. }
  484. auto expand_tuple = ExpandTuplesC(graph, new_input);
  485. (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
  486. }
  487. return new_inputs;
  488. }
  489. // remove most uses of tuples from the graph parameters & apply inputs
  490. // tuples that are returned will be kept
  491. // tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
  492. // CNode("tuple_getitem", (a,b,c), 0)
  493. // CNode("tuple_getitem", (a,b,c), 1)
  494. // CNode("tuple_getitem", (a,b,c), 2)
  495. // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
  496. // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
  497. // cppcheck-suppress unusedFunction
  498. void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  499. MS_EXCEPTION_IF_NULL(manager);
  500. manager->AddFuncGraph(root);
  501. // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  502. AnfNodeSet all_node = manager->all_nodes();
  503. for (auto &node : all_node) {
  504. auto cnode = node->cast<CNodePtr>();
  505. if (cnode == nullptr) {
  506. continue;
  507. }
  508. const auto &inputs = cnode->inputs();
  509. // Bypass the first input in inputs as it's fn.
  510. if (!IsValueNode<Primitive>(inputs[0])) {
  511. std::vector<AnfNodePtr> expand_inputs;
  512. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
  513. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  514. if (new_inputs != expand_inputs) {
  515. std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
  516. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  517. MS_EXCEPTION_IF_NULL(node->func_graph());
  518. auto new_node = node->func_graph()->NewCNode(cnode_inputs);
  519. new_node->set_abstract(node->abstract());
  520. (void)manager->Replace(node, new_node);
  521. }
  522. // Bypass the first 2 inputs in inputs as it's [partial, fn].
  523. } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
  524. std::vector<AnfNodePtr> expand_inputs;
  525. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
  526. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  527. if (new_inputs != expand_inputs) {
  528. std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
  529. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  530. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  531. auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
  532. new_node->set_abstract(cnode->abstract());
  533. (void)manager->Replace(node, new_node);
  534. }
  535. }
  536. }
  537. FuncGraphSet all_graph = manager->func_graphs();
  538. for (auto &func_graph : all_graph) {
  539. MS_EXCEPTION_IF_NULL(func_graph);
  540. auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
  541. manager->SetParameters(func_graph, expand_p);
  542. }
  543. }
  544. } // namespace opt
  545. } // namespace mindspore