You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cse.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "frontend/optimizer/cse.h"
  19. #include <vector>
  20. #include <set>
  21. #include <unordered_map>
  22. #include <unordered_set>
  23. #include <algorithm>
  24. #include "abstract/abstract_function.h"
  25. #include "utils/flags.h"
  26. #include "utils/utils.h"
  27. #include "base/core_ops.h"
  28. namespace mindspore {
  29. /* namespace to support opt */
  30. namespace opt {
  31. using mindspore::abstract::AbstractBase;
  32. using mindspore::abstract::AbstractFunction;
  33. using mindspore::abstract::AbstractFunctionPtr;
  34. bool WithRecomputedScope(const AnfNodePtr &node) {
  35. MS_EXCEPTION_IF_NULL(node);
  36. if (!node->isa<CNode>()) {
  37. return false;
  38. }
  39. auto full_name_with_scope = node->fullname_with_scope();
  40. return full_name_with_scope.find(kAttrRecompute) == 0;
  41. }
  42. bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
  43. return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) ||
  44. (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
  45. }
  46. BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
  47. MS_EXCEPTION_IF_NULL(node);
  48. auto node_abs = node->abstract();
  49. // In testcase: TestOptOpt.CSE, node->abstract() is null.
  50. if (node_abs == nullptr) {
  51. return kAnyValue;
  52. }
  53. if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
  54. // Ignore the tracking_id and prim pointer hash.
  55. auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
  56. return prim_abs->prim();
  57. } else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
  58. // Ignore the tracking_id.
  59. auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy();
  60. new_fg_abs->set_tracking_id(nullptr);
  61. return new_fg_abs;
  62. }
  63. return node_abs;
  64. }
  65. bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
  66. bool changed = false;
  67. for (FuncGraphPtr fg : manager->func_graphs()) {
  68. MS_EXCEPTION_IF_NULL(fg);
  69. std::vector<std::size_t> order_group;
  70. std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
  71. std::unordered_map<AnfNodePtr, std::size_t> hashes;
  72. std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
  73. for (auto node : toposet) {
  74. MS_EXCEPTION_IF_NULL(node);
  75. if (hashes.find(node) != hashes.end()) {
  76. continue;
  77. }
  78. std::size_t h = 0;
  79. if (node->isa<ValueNode>()) {
  80. ValueNodePtr value_node = node->cast<ValueNodePtr>();
  81. auto value = value_node->value();
  82. MS_EXCEPTION_IF_NULL(value);
  83. h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
  84. } else if (node->isa<CNode>()) {
  85. auto cnode = node->cast<CNodePtr>();
  86. auto &inputs = cnode->inputs();
  87. size_t init = 0;
  88. h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
  89. return hash_combine(hash, hashes[node_in]);
  90. });
  91. } else if (node->isa<Parameter>()) {
  92. h = node->hash();
  93. } else {
  94. MS_LOG(ERROR) << "Unknown node type";
  95. }
  96. hashes[node] = h;
  97. if (groups.find(h) == groups.end()) {
  98. std::vector<AnfNodePtr> innervec({node});
  99. groups[h] = innervec;
  100. order_group.emplace_back(h);
  101. } else {
  102. groups[h].push_back(node);
  103. }
  104. }
  105. changed = DoReplace(manager, order_group, &groups) || changed;
  106. }
  107. return changed;
  108. }
  109. std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
  110. std::vector<AnfNodePtr> *need_replace_loads) {
  111. std::unordered_map<AnfNodePtr, size_t> load_groups_record;
  112. std::vector<std::vector<size_t>> load_groups;
  113. std::unordered_set<AnfNodePtr> unload_users_record;
  114. for (size_t i = 0; i < toposet.size(); i++) {
  115. auto &node = toposet[i];
  116. auto cnode = node->cast<CNodePtr>();
  117. if (cnode == nullptr) {
  118. continue;
  119. }
  120. if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
  121. for (const auto &input : cnode->inputs()) {
  122. if (input->isa<Parameter>() ||
  123. (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) {
  124. unload_users_record.insert(input);
  125. }
  126. }
  127. continue;
  128. }
  129. // Exclude free variable node.
  130. if (cnode->func_graph() != fg) {
  131. continue;
  132. }
  133. auto load_param = cnode->input(1);
  134. // first time get same input1 of load.
  135. if (load_groups_record.find(load_param) == load_groups_record.end()) {
  136. load_groups_record[load_param] = load_groups.size();
  137. load_groups.push_back({i});
  138. if (unload_users_record.find(load_param) == unload_users_record.end()) {
  139. need_replace_loads->emplace_back(cnode);
  140. }
  141. } else {
  142. // not first time get same input1 of load
  143. load_groups[load_groups_record[load_param]].push_back(i);
  144. }
  145. }
  146. return load_groups;
  147. }
  148. std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
  149. if (group.size() <= 1) {
  150. return {};
  151. }
  152. auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
  153. size_t cur_load_index = 1;
  154. size_t pre_load_index = 0;
  155. std::vector<size_t> cur_group = {group[pre_load_index]};
  156. std::vector<std::vector<size_t>> split_groups;
  157. while (cur_load_index < group.size()) {
  158. const auto &cur_load = group[cur_load_index];
  159. const auto &prev_load = group[pre_load_index];
  160. const auto param_used_by_other =
  161. std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) {
  162. if (!node->isa<CNode>()) {
  163. return false;
  164. }
  165. if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
  166. return false;
  167. }
  168. auto cnode = node->cast<CNodePtr>();
  169. auto &inputs = cnode->inputs();
  170. return std::any_of(inputs.begin(), inputs.end(),
  171. [&load_param](const AnfNodePtr &input) { return load_param == input; });
  172. });
  173. if (param_used_by_other) {
  174. split_groups.push_back(cur_group);
  175. cur_group.clear();
  176. }
  177. cur_group.push_back(cur_load);
  178. pre_load_index++;
  179. cur_load_index++;
  180. }
  181. // push back the last splited group.
  182. split_groups.push_back(cur_group);
  183. return split_groups;
  184. }
  185. // Pattern1======================================
  186. // a = Load(para1, u1)
  187. // ...
  188. // b = Load(para1, u2)
  189. // u3 = UpdateState(u2, b)
  190. //==>
  191. // delete the UpdateState
  192. void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
  193. const AnfNodePtr &load) {
  194. const auto &load_cnode = load->cast<CNodePtr>();
  195. const auto &u = load_cnode->input(2);
  196. manager->Replace(load_user, u);
  197. }
  198. // Pattern2======================================
  199. // a = Load(para1, u1)
  200. // ...
  201. // b = Load(para1, u2)
  202. // t = make_tuple(x, b)
  203. // u3 = UpdateState(u2, t)
  204. //==>
  205. // a = Load(para1, u1)
  206. // ...
  207. // b = Load(para1, u2)
  208. // u3 = UpdateState(u2, x)
  209. void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
  210. // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
  211. AnfNodePtr other_input = load;
  212. for (size_t i = 1; i < make_tuple->size(); i++) {
  213. if (make_tuple->input(i) != load) {
  214. other_input = make_tuple->input(i);
  215. break;
  216. }
  217. }
  218. MS_EXCEPTION_IF_NULL(other_input);
  219. manager->Replace(make_tuple, other_input);
  220. }
  221. // Pattern3======================================
  222. // a = Load(para1, u1)
  223. // ...
  224. // b = Load(para1, u2)
  225. // t = make_tuple(x, y, b, z)
  226. // u3 = UpdateState(u2, t)
  227. //==>
  228. // a = Load(para1, u1)
  229. // ...
  230. // b = Load(para1, u2)
  231. // t = make_tuple(x, y, z)
  232. // u3 = UpdateState(u2, t)
  233. void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
  234. const AnfNodePtr &load) {
  235. auto &make_tuple_inputs = make_tuple->inputs();
  236. std::vector<AnfNodePtr> new_make_tuple_inputs;
  237. (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
  238. [load](const AnfNodePtr &input) { return load != input; });
  239. const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
  240. new_make_tuple->set_abstract(make_tuple->abstract());
  241. manager->Replace(make_tuple, new_make_tuple);
  242. }
  243. void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
  244. auto load_users = manager->node_users()[load];
  245. for (const auto &load_user : load_users) {
  246. // Pattern1
  247. if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
  248. DeleteLoadUserUpdateState(manager, load_user.first, load);
  249. continue;
  250. }
  251. if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
  252. const auto &make_tuple = load_user.first->cast<CNodePtr>();
  253. auto &maketuple_users = manager->node_users()[make_tuple];
  254. auto maketuple_as_input_of_update =
  255. maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
  256. if (!maketuple_as_input_of_update) {
  257. continue;
  258. }
  259. // Pattern2
  260. if (make_tuple->size() == 3) {
  261. DeleteLoadUserMakeTuple(manager, make_tuple, load);
  262. continue;
  263. }
  264. // Pattern3
  265. if (make_tuple->size() > 3) {
  266. ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
  267. }
  268. }
  269. }
  270. }
  271. bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
  272. const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
  273. if (group.size() <= 1) {
  274. return false;
  275. }
  276. const auto &main = toposet[group[0]];
  277. for (size_t i = 1; i < group.size(); i++) {
  278. ReplaceLoadUser(manager, fg, toposet[group[i]]);
  279. manager->Replace(toposet[group[i]], main);
  280. }
  281. return true;
  282. }
  283. AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
  284. auto &params = fg->parameters();
  285. auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
  286. auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
  287. if (iter != end) {
  288. return *iter;
  289. }
  290. auto monad = NewValueNode(kUMonad);
  291. monad->set_abstract(kUMonad->ToAbstract());
  292. return monad;
  293. }
  294. // Replace UpdateStates with U for first load.
  295. // Covert:
  296. // u1 = UpdateState(u, c)
  297. // p1 = Load(para1, u1) // first load for para1
  298. // To:
  299. // u1 = UpdateState(u, c)
  300. // p1 = Load(para1, u') // u' is first monad in graph or new monad
  301. bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
  302. if (need_replace_loads.size() == 0) {
  303. return false;
  304. }
  305. constexpr size_t second_input_index = 2;
  306. auto monad = GetFirstMonad(fg);
  307. for (const auto &load_node : need_replace_loads) {
  308. if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
  309. continue;
  310. }
  311. auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
  312. if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) {
  313. continue;
  314. }
  315. auto mgr = fg->manager();
  316. mgr->SetEdge(load_node, second_input_index, monad);
  317. }
  318. return true;
  319. }
  320. // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
  321. // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
  322. bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
  323. auto changed = false;
  324. for (const FuncGraphPtr &fg : manager->func_graphs()) {
  325. std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
  326. std::vector<AnfNodePtr> need_replace_loads;
  327. std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
  328. const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
  329. if (update_state_replaced) {
  330. changed = true;
  331. }
  332. // split group if there is no-load node between two load nodes.
  333. std::vector<std::vector<size_t>> need_merge_loads;
  334. for (auto &group : load_groups) {
  335. auto groups = SplitGroup(toposet, group);
  336. need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
  337. }
  338. for (auto &group : need_merge_loads) {
  339. const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
  340. if (!changed && replaced) {
  341. changed = true;
  342. }
  343. }
  344. }
  345. return changed;
  346. }
  347. // The op like print, summary, or the op do not has true output, and always as a depend node input.
  348. static bool HasSideEffect(const AnfNodePtr &node) {
  349. auto prim = GetCNodePrimitive(node);
  350. if (prim == nullptr) {
  351. return false;
  352. }
  353. auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
  354. if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
  355. return GetValue<bool>(side_effect_v);
  356. }
  357. return false;
  358. }
  359. // If true do not merge the node.
  360. bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
  361. bool has_random_effect = false;
  362. auto prim_main = GetCNodePrimitive(main);
  363. auto prim_node = GetCNodePrimitive(node);
  364. // if has random effect, when generate by different op (not same object), do not merge.
  365. if (prim_main != nullptr) {
  366. if (prim_main == prim_node) {
  367. return false;
  368. }
  369. auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
  370. if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
  371. has_random_effect = GetValue<bool>(effect_val);
  372. }
  373. }
  374. return has_random_effect;
  375. }
  376. bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
  377. MS_EXCEPTION_IF_NULL(main);
  378. MS_EXCEPTION_IF_NULL(node);
  379. if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
  380. auto main_value = GetValueNode(main);
  381. auto node_value = GetValueNode(node);
  382. return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
  383. } else if (main->isa<CNode>() && node->isa<CNode>()) {
  384. auto c_main = main->cast<CNodePtr>();
  385. auto c_node = node->cast<CNodePtr>();
  386. // Not do cse for the node set recompute before the recompute pass.
  387. if (IsSetRecomputed(c_main, c_node)) {
  388. return false;
  389. }
  390. // When appsame is true, check if has side effect, do not merge.
  391. if (check_side_effect && HasSideEffect(main)) {
  392. return false;
  393. }
  394. const auto &inp1 = c_main->inputs();
  395. const auto &inp2 = c_node->inputs();
  396. if (inp1.size() != inp2.size()) {
  397. return false;
  398. }
  399. for (size_t j = 0; j < inp1.size(); j++) {
  400. auto inp1_j = inp1[j];
  401. auto inp2_j = inp2[j];
  402. MS_EXCEPTION_IF_NULL(inp1_j);
  403. MS_EXCEPTION_IF_NULL(inp2_j);
  404. if (!(*inp1_j == *inp2_j)) {
  405. // Handle the case of two different Tensor, but with the same value
  406. if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
  407. auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
  408. auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
  409. if (tensor1->ValueEqual(*tensor2)) {
  410. continue;
  411. }
  412. } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
  413. // When the same side effect node as another two nodes' inputs, we still merge the node.
  414. // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
  415. // node.
  416. if (CheckReplace(inp1_j, inp2_j, false)) {
  417. continue;
  418. }
  419. }
  420. return false;
  421. }
  422. }
  423. // When appsame is true, check if has random effect do not merge
  424. if (CheckRandomEffect(c_main, c_node)) {
  425. return false;
  426. }
  427. return true;
  428. }
  429. // a parameter node.
  430. return false;
  431. }
  432. bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
  433. std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const {
  434. bool changes = false;
  435. std::set<size_t> clear_set;
  436. for (auto &h : order_group) {
  437. std::vector<AnfNodePtr> &group = (*groups)[h];
  438. // If there are more than 2 node in that group, they may be same common expression can be eliminated.
  439. if (group.size() > 1) {
  440. for (size_t k = 0; k < group.size() - 1; k++) {
  441. AnfNodePtr main = group[k];
  442. MS_EXCEPTION_IF_NULL(main);
  443. // When all node in group has been replaced
  444. // or a valuenode node, skip compare in group
  445. if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa<ValueNode>())) {
  446. break;
  447. }
  448. // skip node has been replaced
  449. if (clear_set.find(k) != clear_set.end()) {
  450. continue;
  451. }
  452. // Compare with rest elements in this group.
  453. for (size_t i = k + 1; i < group.size(); i++) {
  454. auto node = group[i];
  455. MS_EXCEPTION_IF_NULL(node);
  456. if (clear_set.find(i) != clear_set.end()) {
  457. continue;
  458. }
  459. if (main->func_graph() != node->func_graph()) {
  460. continue;
  461. }
  462. if (CheckReplace(node, main)) {
  463. changes = true;
  464. (void)manager->Replace(node, main);
  465. (void)clear_set.insert(i);
  466. }
  467. }
  468. }
  469. clear_set.clear();
  470. }
  471. }
  472. return changes;
  473. }
  474. bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const {
  475. MS_EXCEPTION_IF_NULL(manager);
  476. manager->AddFuncGraph(root);
  477. auto change1 = ReplaceAutoMonadNode(manager);
  478. auto change2 = BuildOrderGroupAndDoReplace(manager);
  479. return change1 || change2;
  480. }
  481. } // namespace opt
  482. } // namespace mindspore