You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clean.cc 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/optimizer/clean.h"
  19. #include <string>
  20. #include <vector>
  21. #include <algorithm>
  22. #include "debug/trace.h"
  23. #include "frontend/operator/composite/composite.h"
  24. #include "pipeline/jit/parse/resolve.h"
  25. namespace mindspore {
  26. /* namespace to support opt */
  27. namespace opt {
  28. using mindspore::abstract::AbstractAttribute;
  29. using mindspore::abstract::AbstractClass;
  30. using mindspore::abstract::AbstractDictionary;
  31. using mindspore::abstract::AbstractJTagged;
  32. using mindspore::abstract::AbstractList;
  33. using mindspore::abstract::AbstractRowTensor;
  34. using mindspore::abstract::AbstractScalar;
  35. using mindspore::abstract::AbstractSparseTensor;
  36. using mindspore::abstract::AbstractTuple;
  37. using mindspore::abstract::AbstractUndetermined;
  38. inline void CheckInputsSize(size_t actual_size, size_t expect_size, const std::string &op_name) {
  39. if (actual_size != expect_size) {
  40. MS_LOG(EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << actual_size;
  41. }
  42. }
  43. static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  44. if (t == nullptr) {
  45. return nullptr;
  46. }
  47. if (t->isa<AbstractClass>()) {
  48. auto abs_class = dyn_cast<AbstractClass>(t);
  49. AbstractBasePtrList baselist;
  50. auto attributes = abs_class->attributes();
  51. (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
  52. [](const AbstractAttribute &item) { return item.second; });
  53. return std::make_shared<AbstractTuple>(baselist);
  54. }
  55. if (t->isa<AbstractDictionary>()) {
  56. auto abs_dict = dyn_cast<AbstractDictionary>(t);
  57. AbstractBasePtrList baselist;
  58. auto elements = abs_dict->elements();
  59. (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
  60. [](const AbstractAttribute &item) { return item.second; });
  61. return std::make_shared<AbstractTuple>(baselist);
  62. }
  63. return nullptr;
  64. }
  65. static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  66. if (t == nullptr) {
  67. return nullptr;
  68. }
  69. if (t->isa<AbstractList>()) {
  70. auto abs_list = dyn_cast<AbstractList>(t);
  71. return std::make_shared<AbstractTuple>(abs_list->elements());
  72. }
  73. if (t->isa<AbstractSparseTensor>()) {
  74. auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
  75. std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
  76. return std::make_shared<AbstractTuple>(abstract_list);
  77. }
  78. if (t->isa<AbstractRowTensor>()) {
  79. auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
  80. std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
  81. abs_row_tensor->dense_shape()};
  82. return std::make_shared<AbstractTuple>(abstract_list);
  83. }
  84. return nullptr;
  85. }
  86. AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
  87. MS_EXCEPTION_IF_NULL(node);
  88. MS_EXCEPTION_IF_NULL(node->func_graph());
  89. const auto &inputs = node->inputs();
  90. // Inputs should be [getattr, data, attribute]
  91. const size_t expect_inputs_size = 3;
  92. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  93. constexpr size_t data_index = 1;
  94. constexpr size_t attribute_index = 2;
  95. AnfNodePtr data = inputs[data_index];
  96. AnfNodePtr cons = inputs[attribute_index];
  97. MS_EXCEPTION_IF_NULL(data);
  98. MS_EXCEPTION_IF_NULL(cons);
  99. auto dt = data->abstract();
  100. if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
  101. return nullptr;
  102. }
  103. if (!dt->isa<AbstractClass>()) {
  104. MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
  105. }
  106. auto cons_is_str = IsValueNode<StringImm>(cons);
  107. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  108. auto ct = dyn_cast<AbstractClass>(dt);
  109. const auto &cmap = ct->attributes();
  110. int64_t count = 0;
  111. for (auto &item : cmap) {
  112. if (cons_is_str && item.first == cons_str) {
  113. break;
  114. }
  115. count++;
  116. }
  117. auto idx_c = NewValueNode(count);
  118. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
  119. idx_c->set_abstract(aptr);
  120. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  121. }
  122. AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
  123. MS_EXCEPTION_IF_NULL(node);
  124. MS_EXCEPTION_IF_NULL(node->func_graph());
  125. // Inputs should be [dict_getitem, dict, item]
  126. const auto &inputs = node->inputs();
  127. const size_t expect_inputs_size = 3;
  128. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  129. constexpr size_t data_index = 1;
  130. constexpr size_t cons_index = 2;
  131. AnfNodePtr data = inputs[data_index];
  132. AnfNodePtr cons = inputs[cons_index];
  133. MS_EXCEPTION_IF_NULL(data);
  134. MS_EXCEPTION_IF_NULL(cons);
  135. auto dt = data->abstract();
  136. MS_EXCEPTION_IF_NULL(dt);
  137. if (!dt->isa<abstract::AbstractDictionary>()) {
  138. MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
  139. }
  140. auto cons_is_str = IsValueNode<StringImm>(cons);
  141. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  142. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  143. const auto &cmap = ct->elements();
  144. int64_t count = 0;
  145. for (auto &item : cmap) {
  146. if (cons_is_str && item.first == cons_str) {
  147. break;
  148. }
  149. count++;
  150. }
  151. auto idx_c = NewValueNode(count);
  152. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
  153. idx_c->set_abstract(aptr);
  154. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  155. }
  156. AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
  157. MS_EXCEPTION_IF_NULL(node);
  158. MS_EXCEPTION_IF_NULL(node->func_graph());
  159. // Inputs should be [dict_setitem, dict, item, value]
  160. const auto &inputs = node->inputs();
  161. const size_t expect_inputs_size = 4;
  162. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  163. const size_t data_index = 1;
  164. const size_t cons_index = 2;
  165. const size_t item_value_index = 3;
  166. AnfNodePtr data = inputs[data_index];
  167. AnfNodePtr cons = inputs[cons_index];
  168. AnfNodePtr item_value = inputs[item_value_index];
  169. MS_EXCEPTION_IF_NULL(data);
  170. MS_EXCEPTION_IF_NULL(cons);
  171. auto dt = data->abstract();
  172. MS_EXCEPTION_IF_NULL(dt);
  173. if (!dt->isa<abstract::AbstractDictionary>()) {
  174. MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
  175. }
  176. auto cons_is_str = IsValueNode<StringImm>(cons);
  177. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  178. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  179. const auto &cmap = ct->elements();
  180. int64_t count = 0;
  181. for (auto &item : cmap) {
  182. if (cons_is_str && item.first == cons_str) {
  183. break;
  184. }
  185. count++;
  186. }
  187. if (LongToSize(count) >= cmap.size()) {
  188. // for dictionary set, if the key does not exist, we should create a new item
  189. auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
  190. auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
  191. return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
  192. }
  193. auto idx_c = NewValueNode(count);
  194. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
  195. idx_c->set_abstract(aptr);
  196. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
  197. }
  198. AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
  199. MS_EXCEPTION_IF_NULL(node);
  200. MS_EXCEPTION_IF_NULL(node->func_graph());
  201. std::vector<AnfNodePtr> inputs;
  202. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  203. // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
  204. (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
  205. return node->func_graph()->NewCNode(inputs);
  206. }
  207. AnfNodePtr ErasePartialNode(const CNodePtr &node) {
  208. MS_EXCEPTION_IF_NULL(node);
  209. MS_EXCEPTION_IF_NULL(node->func_graph());
  210. const auto &inputs = node->inputs();
  211. // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
  212. const size_t min_inputs_size = 2;
  213. if (inputs.size() < min_inputs_size) {
  214. MS_LOG(EXCEPTION) << "Partial should have at least 2 inputs, but got " << inputs.size();
  215. }
  216. std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
  217. auto oper = inputs[1];
  218. if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
  219. if (args.size() == 1) {
  220. return NewValueNode(prim::kPrimMakeTuple);
  221. }
  222. if (args.size() > 1) {
  223. std::vector<AnfNodePtr> new_inputs;
  224. new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
  225. new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  226. (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
  227. MS_EXCEPTION_IF_NULL(node->func_graph());
  228. return node->func_graph()->NewCNode(new_inputs);
  229. }
  230. }
  231. return nullptr;
  232. }
  233. AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
  234. MS_EXCEPTION_IF_NULL(node);
  235. MS_EXCEPTION_IF_NULL(node->func_graph());
  236. std::vector<AnfNodePtr> inputs;
  237. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  238. // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
  239. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  240. return node->func_graph()->NewCNode(inputs);
  241. }
  242. AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
  243. MS_EXCEPTION_IF_NULL(node);
  244. MS_EXCEPTION_IF_NULL(node->func_graph());
  245. const auto &inputs = node->inputs();
  246. // Inputs should be [list_getitem, list, item]
  247. constexpr size_t expect_input_size = 3;
  248. CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
  249. constexpr size_t real_input_index = 1;
  250. constexpr size_t index_input_index = 2;
  251. AnfNodePtr data = inputs[real_input_index];
  252. AnfNodePtr cons = inputs[index_input_index];
  253. MS_EXCEPTION_IF_NULL(data);
  254. MS_EXCEPTION_IF_NULL(cons);
  255. auto cons_node = cons->cast<ValueNodePtr>();
  256. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
  257. }
  258. AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
  259. MS_EXCEPTION_IF_NULL(node);
  260. MS_EXCEPTION_IF_NULL(node->func_graph());
  261. const auto &inputs = node->inputs();
  262. // Inputs should be [list_setitem, list, index, item]
  263. const size_t expect_inputs_size = 4;
  264. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  265. const size_t data_index = 1;
  266. const size_t cons_index = 2;
  267. const size_t value_index = 3;
  268. AnfNodePtr data = inputs[data_index];
  269. AnfNodePtr cons = inputs[cons_index];
  270. AnfNodePtr value = inputs[value_index];
  271. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
  272. }
  273. AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
  274. MS_EXCEPTION_IF_NULL(node);
  275. const auto &inputs = node->inputs();
  276. const size_t expect_inputs_size = 3;
  277. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  278. return inputs[2];
  279. }
  280. AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
  281. MS_EXCEPTION_IF_NULL(node);
  282. const auto &inputs = node->inputs();
  283. const size_t expect_inputs_size = 2;
  284. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  285. return inputs[1];
  286. }
  287. AnfNodePtr EraseDictItems(const CNodePtr &node) {
  288. MS_EXCEPTION_IF_NULL(node);
  289. const auto &inputs = node->inputs();
  290. const size_t expect_inputs_size = 2;
  291. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  292. const auto &tmp = inputs[0]->cast<ValueNodePtr>();
  293. MS_EXCEPTION_IF_NULL(tmp);
  294. MS_EXCEPTION_IF_NULL(tmp->value()->cast<ValueTuplePtr>());
  295. ValuePtrList keys = tmp->value()->cast<ValueTuplePtr>()->value();
  296. std::vector<AnfNodePtr> outer_node{NewValueNode(prim::kPrimMakeList)};
  297. for (size_t i = 0; i < keys.size(); ++i) {
  298. std::vector<AnfNodePtr> inner_node;
  299. inner_node.push_back(NewValueNode(prim::kPrimMakeTuple));
  300. inner_node.push_back(NewValueNode(keys[i]));
  301. inner_node.push_back(NewCNode(
  302. std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), inputs[1], NewValueNode(i)}, node->func_graph()));
  303. outer_node.push_back(NewCNode(inner_node, node->func_graph()));
  304. }
  305. return NewCNode(outer_node, node->func_graph());
  306. }
  307. AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
  308. MS_EXCEPTION_IF_NULL(node);
  309. const auto &inputs = node->inputs();
  310. // Inputs should be [make_keyword_arg, key, value]
  311. constexpr size_t expect_input_size = 3;
  312. constexpr size_t value_inputs_index = 2;
  313. CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
  314. return inputs[value_inputs_index];
  315. }
  316. AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
  317. MS_EXCEPTION_IF_NULL(node);
  318. const auto &inputs = node->inputs();
  319. // Inputs should be [extract_keyword_arg, arg, key]
  320. const size_t expect_inputs_size = 3;
  321. CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
  322. constexpr size_t key_index = 2;
  323. return inputs[key_index];
  324. }
  325. ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
  326. const int64_t DEPTH_MAX = 5;
  327. if (depth > DEPTH_MAX) {
  328. MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels.";
  329. }
  330. std::vector<ValuePtr> elements;
  331. for (const auto &it : value_list->value()) {
  332. ValuePtr value = nullptr;
  333. if (it->isa<ValueList>()) {
  334. value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
  335. } else {
  336. value = it;
  337. }
  338. elements.push_back(value);
  339. }
  340. return std::make_shared<ValueTuple>(elements);
  341. }
  342. AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
  343. MS_EXCEPTION_IF_NULL(node);
  344. ValuePtr value = node->value();
  345. auto value_list = value->cast<ValueListPtr>();
  346. MS_EXCEPTION_IF_NULL(value_list);
  347. int64_t depth = 0;
  348. return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
  349. }
  350. // Convert class to Tuple
  351. // Convert getattr to getitem
  352. // Convert make_record to make_tuple
  353. bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  354. MS_EXCEPTION_IF_NULL(manager);
  355. manager->AddFuncGraph(root);
  356. bool changed = false;
  357. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  358. AnfNodeSet all_node = manager->all_nodes();
  359. for (auto &node : all_node) {
  360. MS_EXCEPTION_IF_NULL(node);
  361. auto cnode = node->cast<CNodePtr>();
  362. AnfNodePtr new_node = nullptr;
  363. if (IsValueNode<parse::ClassObject>(node)) {
  364. new_node = NewValueNode(prim::kPrimMakeTuple);
  365. } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
  366. new_node = ConvertGetAttrToTupleGetItem(cnode);
  367. } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
  368. new_node = ConvertMakeRecordToMakeTuple(cnode);
  369. } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  370. new_node = ErasePartialNode(cnode);
  371. } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
  372. new_node = ConvertDictGetItemToTupleGetItem(cnode);
  373. } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
  374. new_node = ConvertDictSetItemToTupleSetItem(cnode);
  375. } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
  376. new_node = EraseDictGetValues(cnode);
  377. } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
  378. new_node = EraseMakeDictNode(cnode);
  379. } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
  380. new_node = EraseMakeKeywordArgNode(cnode);
  381. } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
  382. new_node = EraseExtractKeywordArg(cnode);
  383. } else if (IsPrimitiveCNode(node, prim::kPrimDictItems)) {
  384. new_node = EraseDictItems(cnode);
  385. }
  386. if (new_node != nullptr) {
  387. new_node->set_abstract(node->abstract());
  388. MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
  389. (void)manager->Replace(node, new_node);
  390. changed = true;
  391. }
  392. }
  393. for (auto &node : manager->all_nodes()) {
  394. auto ret = Reabs(node->abstract());
  395. if (ret) {
  396. MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
  397. << ret->ToString();
  398. node->set_abstract(ret);
  399. if (ret->cast<abstract::AbstractTuplePtr>()->size() > 0) {
  400. changed = true;
  401. }
  402. }
  403. }
  404. return changed;
  405. }
  406. AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
  407. MS_EXCEPTION_IF_NULL(node);
  408. MS_EXCEPTION_IF_NULL(node->func_graph());
  409. std::vector<AnfNodePtr> inputs;
  410. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  411. // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
  412. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  413. return node->func_graph()->NewCNode(inputs);
  414. }
  415. AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int64_t &index) {
  416. MS_EXCEPTION_IF_NULL(node);
  417. MS_EXCEPTION_IF_NULL(node->func_graph());
  418. const auto &inputs = node->inputs();
  419. // Inputs should be [sparse_getattr, sparse]
  420. constexpr size_t expect_input_index = 2;
  421. CheckInputsSize(inputs.size(), expect_input_index, GetCNodeFuncName(node));
  422. constexpr size_t sparse_index = 1;
  423. AnfNodePtr sparse = inputs[sparse_index];
  424. MS_EXCEPTION_IF_NULL(sparse);
  425. auto cons_node = NewValueNode(index);
  426. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(index));
  427. cons_node->set_abstract(aptr);
  428. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
  429. }
  430. bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  431. MS_EXCEPTION_IF_NULL(manager);
  432. manager->AddFuncGraph(root);
  433. bool changed = false;
  434. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  435. auto all_node = manager->all_nodes();
  436. for (auto &node : all_node) {
  437. MS_EXCEPTION_IF_NULL(node);
  438. auto cnode = node->cast<CNodePtr>();
  439. AnfNodePtr new_node = nullptr;
  440. if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
  441. new_node = ConvertMakeListToMakeTuple(cnode);
  442. } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
  443. new_node = ConvertListGetItemToTupleGetItem(cnode);
  444. } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
  445. new_node = ConvertListSetItemToTupleSetItem(cnode);
  446. } else if (IsValueNode<ValueList>(node)) {
  447. new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
  448. } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
  449. IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
  450. new_node = ConvertMakeSparseToMakeTuple(cnode);
  451. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
  452. IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
  453. constexpr int64_t indices_index = 0;
  454. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, indices_index);
  455. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
  456. IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
  457. constexpr int64_t value_index = 1;
  458. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, value_index);
  459. } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
  460. IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
  461. constexpr int64_t shape_index = 2;
  462. new_node = ConvertSparseGetAttrToTupleGetItem(cnode, shape_index);
  463. }
  464. if (new_node != nullptr) {
  465. new_node->set_abstract(node->abstract());
  466. MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
  467. (void)manager->Replace(node, new_node);
  468. changed = true;
  469. }
  470. }
  471. for (auto &node : manager->all_nodes()) {
  472. auto ret = AdaptAbs(node->abstract());
  473. if (ret) {
  474. MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
  475. << ret->ToString();
  476. node->set_abstract(ret);
  477. changed = true;
  478. }
  479. }
  480. return changed;
  481. }
  482. } // namespace opt
  483. } // namespace mindspore