You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clean.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "optimizer/clean.h"
  19. #include <map>
  20. #include <string>
  21. #include <vector>
  22. #include <algorithm>
  23. #include <utility>
  24. #include "./common.h"
  25. #include "debug/trace.h"
  26. namespace mindspore {
  27. /* namespace to support opt */
  28. namespace opt {
  29. using mindspore::abstract::AbstractAttribute;
  30. using mindspore::abstract::AbstractClass;
  31. using mindspore::abstract::AbstractDictionary;
  32. using mindspore::abstract::AbstractJTagged;
  33. using mindspore::abstract::AbstractList;
  34. using mindspore::abstract::AbstractScalar;
  35. using mindspore::abstract::AbstractTuple;
  36. static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  37. if (t == nullptr) {
  38. return nullptr;
  39. }
  40. AbstractBasePtr res = t;
  41. if (t->isa<AbstractClass>()) {
  42. auto abs_class = dyn_cast<AbstractClass>(t);
  43. AbstractBasePtrList baselist;
  44. auto attributes = abs_class->attributes();
  45. (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
  46. [](const AbstractAttribute &item) { return item.second; });
  47. res = std::make_shared<AbstractTuple>(baselist);
  48. } else if (t->isa<AbstractDictionary>()) {
  49. auto abs_dict = dyn_cast<AbstractDictionary>(t);
  50. AbstractBasePtrList baselist;
  51. auto elements = abs_dict->elements();
  52. (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
  53. [](const AbstractAttribute &item) { return item.second; });
  54. res = std::make_shared<AbstractTuple>(baselist);
  55. } else if (t->isa<AbstractList>()) {
  56. auto abs_dict = dyn_cast<AbstractList>(t);
  57. res = std::make_shared<AbstractTuple>(abs_dict->elements());
  58. }
  59. return res;
  60. }
  61. AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
  62. MS_EXCEPTION_IF_NULL(node);
  63. MS_EXCEPTION_IF_NULL(node->func_graph());
  64. const auto &inputs = node->inputs();
  65. // Inputs should be [getattr, data, attribute]
  66. MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
  67. AnfNodePtr data = inputs[1];
  68. AnfNodePtr cons = inputs[2];
  69. MS_EXCEPTION_IF_NULL(data);
  70. MS_EXCEPTION_IF_NULL(cons);
  71. auto dt = data->abstract();
  72. MS_EXCEPTION_IF_NULL(dt);
  73. if (!dt->isa<AbstractClass>()) {
  74. MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
  75. }
  76. auto cons_is_str = IsValueNode<StringImm>(cons);
  77. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  78. auto ct = dyn_cast<AbstractClass>(dt);
  79. const auto &cmap = ct->attributes();
  80. int count = 0;
  81. for (auto &item : cmap) {
  82. if (cons_is_str && item.first == cons_str) {
  83. break;
  84. }
  85. count++;
  86. }
  87. auto idx_c = NewValueNode(count);
  88. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  89. idx_c->set_abstract(aptr);
  90. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  91. }
  92. AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
  93. MS_EXCEPTION_IF_NULL(node);
  94. MS_EXCEPTION_IF_NULL(node->func_graph());
  95. // Inputs should be [dict_getitem, dict, item]
  96. const auto &inputs = node->inputs();
  97. MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
  98. AnfNodePtr data = inputs[1];
  99. AnfNodePtr cons = inputs[2];
  100. MS_EXCEPTION_IF_NULL(data);
  101. MS_EXCEPTION_IF_NULL(cons);
  102. auto dt = data->abstract();
  103. MS_EXCEPTION_IF_NULL(dt);
  104. if (!dt->isa<abstract::AbstractDictionary>()) {
  105. MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
  106. }
  107. auto cons_is_str = IsValueNode<StringImm>(cons);
  108. auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
  109. auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
  110. const auto &cmap = ct->elements();
  111. int count = 0;
  112. for (auto &item : cmap) {
  113. if (cons_is_str && item.first == cons_str) {
  114. break;
  115. }
  116. count++;
  117. }
  118. auto idx_c = NewValueNode(count);
  119. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
  120. idx_c->set_abstract(aptr);
  121. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
  122. }
  123. AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
  124. MS_EXCEPTION_IF_NULL(node);
  125. MS_EXCEPTION_IF_NULL(node->func_graph());
  126. std::vector<AnfNodePtr> inputs;
  127. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  128. // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
  129. (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
  130. return node->func_graph()->NewCNode(inputs);
  131. }
  132. AnfNodePtr ErasePartialNode(const CNodePtr &node) {
  133. MS_EXCEPTION_IF_NULL(node);
  134. MS_EXCEPTION_IF_NULL(node->func_graph());
  135. const auto &inputs = node->inputs();
  136. // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
  137. MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
  138. std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
  139. auto oper = inputs[1];
  140. if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
  141. if (args.size() == 1) {
  142. return NewValueNode(prim::kPrimMakeTuple);
  143. }
  144. if (args.size() > 1) {
  145. std::vector<AnfNodePtr> new_inputs;
  146. new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
  147. new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  148. (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
  149. MS_EXCEPTION_IF_NULL(node->func_graph());
  150. return node->func_graph()->NewCNode(new_inputs);
  151. }
  152. }
  153. return nullptr;
  154. }
  155. AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
  156. MS_EXCEPTION_IF_NULL(node);
  157. MS_EXCEPTION_IF_NULL(node->func_graph());
  158. std::vector<AnfNodePtr> inputs;
  159. inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  160. // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
  161. (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
  162. return node->func_graph()->NewCNode(inputs);
  163. }
  164. AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. MS_EXCEPTION_IF_NULL(node->func_graph());
  167. const auto &inputs = node->inputs();
  168. // Inputs should be [list_getitem, list, item]
  169. if (inputs.size() < 3) {
  170. MS_LOG(EXCEPTION) << "Node's input number < 3.";
  171. }
  172. AnfNodePtr data = inputs[1];
  173. AnfNodePtr cons = inputs[2];
  174. MS_EXCEPTION_IF_NULL(data);
  175. MS_EXCEPTION_IF_NULL(cons);
  176. auto cons_node = cons->cast<ValueNodePtr>();
  177. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
  178. }
  179. AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
  180. MS_EXCEPTION_IF_NULL(node);
  181. MS_EXCEPTION_IF_NULL(node->func_graph());
  182. const auto &inputs = node->inputs();
  183. // Inputs should be [list_setitem, list, index, item]
  184. if (inputs.size() < 4) {
  185. MS_LOG(EXCEPTION) << "Node's input number < 4.";
  186. }
  187. AnfNodePtr data = inputs[1];
  188. AnfNodePtr cons = inputs[2];
  189. AnfNodePtr value = inputs[3];
  190. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
  191. }
  192. AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
  193. MS_EXCEPTION_IF_NULL(node);
  194. const auto &inputs = node->inputs();
  195. MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
  196. return inputs[2];
  197. }
  198. AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
  199. MS_EXCEPTION_IF_NULL(node);
  200. const auto &inputs = node->inputs();
  201. // Inputs should be [make_keyword_arg, key, value]
  202. MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
  203. return inputs[2];
  204. }
  205. AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
  206. MS_EXCEPTION_IF_NULL(node);
  207. const auto &inputs = node->inputs();
  208. // Inputs should be [extract_keyword_arg, arg, key]
  209. MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
  210. return inputs[2];
  211. }
  212. ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
  213. const int DEPTH_MAX = 5;
  214. if (depth > DEPTH_MAX) {
  215. MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
  216. }
  217. std::vector<ValuePtr> elements;
  218. for (const auto &it : value_list->value()) {
  219. ValuePtr value = nullptr;
  220. if (it->isa<ValueList>()) {
  221. value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
  222. } else {
  223. value = it;
  224. }
  225. elements.push_back(value);
  226. }
  227. return std::make_shared<ValueTuple>(elements);
  228. }
  229. AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
  230. MS_EXCEPTION_IF_NULL(node);
  231. ValuePtr value = node->value();
  232. auto value_list = value->cast<ValueListPtr>();
  233. MS_EXCEPTION_IF_NULL(value_list);
  234. int depth = 0;
  235. return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
  236. }
  237. // Convert class to Tuple
  238. // Convert getattr to getitem
  239. // Convert make_record to make_tuple
  240. void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  241. MS_EXCEPTION_IF_NULL(manager);
  242. manager->AddFuncGraph(root);
  243. // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  244. AnfNodeSet all_node = manager->all_nodes();
  245. for (auto &node : all_node) {
  246. MS_EXCEPTION_IF_NULL(node);
  247. auto cnode = node->cast<CNodePtr>();
  248. AnfNodePtr new_node = nullptr;
  249. if (IsValueNode<parse::ClassObject>(node)) {
  250. new_node = NewValueNode(prim::kPrimMakeTuple);
  251. } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
  252. new_node = ConvertGetAttrToTupleGetItem(cnode);
  253. } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
  254. new_node = ConvertMakeRecordToMakeTuple(cnode);
  255. } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  256. new_node = ErasePartialNode(cnode);
  257. } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
  258. new_node = ConvertDictGetItemToTupleGetItem(cnode);
  259. } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
  260. new_node = EraseMakeDictNode(cnode);
  261. } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
  262. new_node = EraseMakeKeywordArgNode(cnode);
  263. } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
  264. new_node = EraseExtractKeywordArg(cnode);
  265. } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
  266. new_node = ConvertMakeListToMakeTuple(cnode);
  267. } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
  268. new_node = ConvertListGetItemToTupleGetItem(cnode);
  269. } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
  270. new_node = ConvertListSetItemToTupleSetItem(cnode);
  271. } else if (IsValueNode<ValueList>(node)) {
  272. new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
  273. }
  274. if (new_node != nullptr) {
  275. new_node->set_abstract(node->abstract());
  276. (void)manager->Replace(node, new_node);
  277. }
  278. }
  279. for (auto &node : manager->all_nodes()) {
  280. auto ret = Reabs(node->abstract());
  281. node->set_abstract(ret);
  282. }
  283. }
  284. // expand tuples in graph parameters
  285. static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
  286. const std::vector<AnfNodePtr> &params) {
  287. MS_EXCEPTION_IF_NULL(mng);
  288. MS_EXCEPTION_IF_NULL(func_graph);
  289. std::vector<AnfNodePtr> new_params;
  290. for (const auto &param : params) {
  291. MS_EXCEPTION_IF_NULL(param);
  292. auto param_abs = param->abstract();
  293. MS_EXCEPTION_IF_NULL(param_abs);
  294. if (param_abs->isa<AbstractJTagged>()) {
  295. MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
  296. }
  297. if (!param_abs->isa<AbstractTuple>()) {
  298. new_params.emplace_back(param);
  299. continue;
  300. }
  301. std::vector<AnfNodePtr> new_param;
  302. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
  303. auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
  304. for (auto &elem : abs_tuple->elements()) {
  305. auto np = std::make_shared<Parameter>(func_graph);
  306. np->set_abstract(elem);
  307. new_param.emplace_back(np);
  308. }
  309. (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
  310. auto new_tuple = func_graph->NewCNode(inputs);
  311. (void)mng->Replace(param, new_tuple);
  312. auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
  313. (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
  314. }
  315. return new_params;
  316. }
  317. // expand tuples in graph applies
  318. static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
  319. MS_EXCEPTION_IF_NULL(graph);
  320. std::vector<AnfNodePtr> new_inputs;
  321. for (const auto &input : inputs) {
  322. MS_EXCEPTION_IF_NULL(input);
  323. auto input_abs = input->abstract();
  324. MS_EXCEPTION_IF_NULL(input_abs);
  325. if (input_abs->isa<AbstractJTagged>()) {
  326. auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
  327. if (abstract_tag->element()->isa<AbstractTuple>()) {
  328. MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
  329. }
  330. }
  331. if (!input_abs->isa<AbstractTuple>()) {
  332. new_inputs.emplace_back(input);
  333. continue;
  334. }
  335. int idx = 0;
  336. std::vector<AnfNodePtr> new_input;
  337. auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
  338. for (auto &elem : abs_tuple->elements()) {
  339. auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
  340. AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
  341. c_node->input(2)->set_abstract(aptr);
  342. c_node->set_abstract(elem);
  343. new_input.emplace_back(c_node);
  344. idx++;
  345. }
  346. auto expand_tuple = ExpandTuplesC(graph, new_input);
  347. (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
  348. }
  349. return new_inputs;
  350. }
  351. // remove most uses of tuples from the graph parameters & apply inputs
  352. // tuples that are returned will be kept
  353. // tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
  354. // CNode("tuple_getitem", (a,b,c), 0)
  355. // CNode("tuple_getitem", (a,b,c), 1)
  356. // CNode("tuple_getitem", (a,b,c), 2)
  357. // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
  358. // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
  359. // cppcheck-suppress unusedFunction
  360. void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  361. MS_EXCEPTION_IF_NULL(manager);
  362. manager->AddFuncGraph(root);
  363. // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
  364. AnfNodeSet all_node = manager->all_nodes();
  365. for (auto &node : all_node) {
  366. auto cnode = node->cast<CNodePtr>();
  367. if (cnode == nullptr) {
  368. continue;
  369. }
  370. const auto &inputs = cnode->inputs();
  371. // Bypass the first input in inputs as it's fn.
  372. if (!IsValueNode<Primitive>(inputs[0])) {
  373. std::vector<AnfNodePtr> expand_inputs;
  374. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
  375. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  376. if (new_inputs != expand_inputs) {
  377. std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
  378. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  379. MS_EXCEPTION_IF_NULL(node->func_graph());
  380. auto new_node = node->func_graph()->NewCNode(cnode_inputs);
  381. new_node->set_abstract(node->abstract());
  382. (void)manager->Replace(node, new_node);
  383. }
  384. // Bypass the first 2 inputs in inputs as it's [partial, fn].
  385. } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
  386. std::vector<AnfNodePtr> expand_inputs;
  387. (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
  388. auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
  389. if (new_inputs != expand_inputs) {
  390. std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
  391. (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
  392. MS_EXCEPTION_IF_NULL(cnode->func_graph());
  393. auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
  394. new_node->set_abstract(cnode->abstract());
  395. (void)manager->Replace(node, new_node);
  396. }
  397. }
  398. }
  399. FuncGraphSet all_graph = manager->func_graphs();
  400. for (auto &func_graph : all_graph) {
  401. MS_EXCEPTION_IF_NULL(func_graph);
  402. auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
  403. manager->SetParameters(func_graph, expand_p);
  404. }
  405. }
  406. } // namespace opt
  407. } // namespace mindspore