You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

auto_monad_eliminate.cc 15 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/auto_monad_eliminate.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <memory>
  20. #include <string>
  21. #include <optional>
  22. #include "utils/hash_map.h"
  23. #include "utils/ordered_map.h"
  24. #include "base/core_ops.h"
  25. #include "abstract/abstract_value.h"
  26. namespace mindspore {
  27. namespace opt {
  28. namespace {
  // Maps a parameter's ref key to the toposort indexes of the nodes recorded as users of that parameter
  // (memory side effect nodes and Depend nodes, see GenerateLoadGroups).
  using ParamUserMap = mindspore::HashMap<std::string, std::vector<size_t>>;
  // Maps a parameter's ref key to the toposort indexes of the Load nodes that read that parameter.
  using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;
  31. std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
  32. auto abs = node->abstract();
  33. if (abs == nullptr) {
  34. // Abstract for some Depends node are not proper set, we follow its input.
  35. if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
  36. return GetRefKey(node->cast<CNodePtr>()->input(1));
  37. }
  38. // Abstract should be set except UpdateState nodes.
  39. if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
  40. MS_LOG(WARNING) << "Abstract not set for " << node->DebugString();
  41. }
  42. return std::nullopt;
  43. }
  44. auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
  45. if (abs_ref == nullptr) {
  46. return std::nullopt;
  47. }
  48. auto ref_key = abs_ref->ref_key_value();
  49. if (ref_key == nullptr) {
  50. return std::nullopt;
  51. }
  52. return ref_key->name();
  53. }
  54. bool HasMemoryEffect(const CNodePtr &cnode) {
  55. const auto &inputs = cnode->inputs();
  56. if (HasAbstractUMonad(inputs.back())) {
  57. // The last input is UMonad.
  58. return true;
  59. }
  60. constexpr size_t kRequiredArgs = 2;
  61. if (inputs.size() > kRequiredArgs) {
  62. // The last two inputs are UMonad and IOMonad.
  63. return HasAbstractIOMonad(inputs.back()) && HasAbstractUMonad(inputs.rbegin()[1]);
  64. }
  65. return false;
  66. }
  67. LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
  68. std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
  69. std::vector<size_t> *special_op_indexes) {
  70. LoadGraphMap load_groups;
  71. for (size_t i = 0; i < toposet.size(); i++) {
  72. auto cnode = dyn_cast<CNode>(toposet[i]);
  73. // Exclude free variable node.
  74. if (cnode == nullptr || cnode->func_graph() != fg) {
  75. continue;
  76. }
  77. // Handle Load node.
  78. if (cnode->IsApply(prim::kPrimLoad)) {
  79. auto ref_key = GetRefKey(cnode->input(1));
  80. if (!ref_key.has_value()) {
  81. MS_LOG(WARNING) << "Load without ref key: " << cnode->DebugString();
  82. continue;
  83. }
  84. // Group load nodes by their input ref key.
  85. auto &group = load_groups[ref_key.value()];
  86. (void)group.emplace_back(i);
  87. if (group.size() == 1) {
  88. // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
  89. // Means there are not nodes which modify param before the load.
  90. const bool param_not_used = (param_users->find(ref_key.value()) == param_users->end());
  91. const bool can_replace = (param_not_used && special_op_indexes->empty());
  92. if (can_replace) {
  93. (void)need_replace_loads->emplace_back(cnode);
  94. }
  95. }
  96. continue;
  97. }
  98. // Record special cnode.
  99. bool is_special_op = IsValueNode<FuncGraph>(cnode->input(0)) || cnode->IsApply(prim::kPrimCall) ||
  100. cnode->IsApply(prim::kPrimPartial) || cnode->IsApply(prim::kPrimSwitch) ||
  101. cnode->IsApply(prim::kPrimSwitchLayer);
  102. if (is_special_op) {
  103. (void)special_op_indexes->emplace_back(i);
  104. continue;
  105. }
  106. // Record param user in toposort nodes.
  107. // We only check memory side effect cnodes or Depend nodes.
  108. if (HasMemoryEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
  109. for (size_t n = 1; n < cnode->size(); ++n) {
  110. const auto &input = cnode->input(n);
  111. auto ref_key = GetRefKey(input);
  112. if (ref_key.has_value()) {
  113. (void)(*param_users)[ref_key.value()].emplace_back(i);
  114. }
  115. }
  116. }
  117. }
  118. return load_groups;
  119. }
  // Returns true when any value in `indexes` lies strictly between `first` and `second`
  // (both bounds excluded).
  bool HasIndexBetween(const std::vector<size_t> &indexes, size_t first, size_t second) {
    return std::any_of(indexes.begin(), indexes.end(),
                       [first, second](size_t index) { return first < index && index < second; });
  }

  // Splits the Load indexes of one parameter into runs that have nothing interfering between them.
  //
  // group:               toposort indexes of the Load nodes of one parameter, in ascending order.
  // param_user_indexes:  toposort indexes of the recorded users of that parameter.
  // special_op_indexes:  toposort indexes of the recorded special ops.
  //
  // Two consecutive loads stay in the same run only if no param user and no special op index lies strictly
  // between them. Returns the runs in order (runs holding a single load are included); returns an empty
  // vector when `group` has fewer than two loads.
  std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
                                              const std::vector<size_t> &param_user_indexes,
                                              const std::vector<size_t> &special_op_indexes) {
    if (group.size() <= 1) {
      return {};
    }
    std::vector<std::vector<size_t>> split_groups;
    std::vector<size_t> cur_group = {group.front()};
    for (size_t i = 1; i < group.size(); ++i) {
      const size_t prev_load = group[i - 1];
      const size_t cur_load = group[i];
      // Something that may touch the parameter sits between the two loads: close the current run.
      if (HasIndexBetween(param_user_indexes, prev_load, cur_load) ||
          HasIndexBetween(special_op_indexes, prev_load, cur_load)) {
        (void)split_groups.emplace_back(std::move(cur_group));
        // Reset explicitly so the reuse below does not depend on the moved-from state.
        cur_group.clear();
      }
      cur_group.push_back(cur_load);
    }
    // Push back the last run; it always holds at least the final load.
    split_groups.push_back(std::move(cur_group));
    return split_groups;
  }
  151. // Pattern1======================================
  152. // a = Load(para1, u1)
  153. // ...
  154. // b = Load(para1, u2)
  155. // u3 = UpdateState(u2, b)
  156. // ==>
  157. // delete the UpdateState
  158. void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
  159. const auto &update_state_cnode = load_user->cast<CNodePtr>();
  160. constexpr size_t monad_index = 1;
  161. const auto &monad = update_state_cnode->input(monad_index);
  162. (void)manager->Replace(load_user, monad);
  163. }
  164. // Pattern2======================================
  165. // a = Load(para1, u1)
  166. // ...
  167. // b = Load(para1, u2)
  168. // t = make_tuple(x, b)
  169. // u3 = UpdateState(u2, t)
  170. //==>
  171. // a = Load(para1, u1)
  172. // ...
  173. // b = Load(para1, u2)
  174. // u3 = UpdateState(u2, x)
  175. void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
  176. // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
  177. AnfNodePtr other_input = load;
  178. for (size_t i = 1; i < make_tuple->size(); i++) {
  179. if (make_tuple->input(i) != load) {
  180. other_input = make_tuple->input(i);
  181. break;
  182. }
  183. }
  184. MS_EXCEPTION_IF_NULL(other_input);
  185. manager->Replace(make_tuple, other_input);
  186. }
  187. // Pattern3======================================
  188. // a = Load(para1, u1)
  189. // ...
  190. // b = Load(para1, u2)
  191. // t = make_tuple(x, y, b, z)
  192. // u3 = UpdateState(u2, t)
  193. //==>
  194. // a = Load(para1, u1)
  195. // ...
  196. // b = Load(para1, u2)
  197. // t = make_tuple(x, y, z)
  198. // u3 = UpdateState(u2, t)
  199. void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
  200. const AnfNodePtr &load) {
  201. auto &make_tuple_inputs = make_tuple->inputs();
  202. std::vector<AnfNodePtr> new_make_tuple_inputs;
  203. (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
  204. [load](const AnfNodePtr &input) { return load != input; });
  205. const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
  206. // Set abstract for the MakeTuple node.
  207. abstract::AbstractBasePtrList element_abstracts;
  208. (void)std::transform(new_make_tuple_inputs.begin() + 1, new_make_tuple_inputs.end(),
  209. std::back_inserter(element_abstracts),
  210. [](const AnfNodePtr &input) { return input->abstract(); });
  211. new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  212. manager->Replace(make_tuple, new_make_tuple);
  213. }
  214. bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
  215. bool change = false;
  216. auto load_users = manager->node_users()[load];
  217. for (const auto &load_user : load_users) {
  218. // Pattern1
  219. if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
  220. DeleteLoadUserUpdateState(manager, load_user.first);
  221. change = true;
  222. continue;
  223. }
  224. if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
  225. const auto &make_tuple = load_user.first->cast<CNodePtr>();
  226. auto &maketuple_users = manager->node_users()[make_tuple];
  227. auto maketuple_as_input_of_update =
  228. maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
  229. if (!maketuple_as_input_of_update) {
  230. continue;
  231. }
  232. // Pattern2
  233. if (make_tuple->size() == 3) {
  234. DeleteLoadUserMakeTuple(manager, make_tuple, load);
  235. change = true;
  236. continue;
  237. }
  238. // Pattern3
  239. if (make_tuple->size() > 3) {
  240. ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
  241. change = true;
  242. }
  243. }
  244. }
  245. return change;
  246. }
  247. bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
  248. const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
  249. if (group.size() <= 1) {
  250. return false;
  251. }
  252. bool change = false;
  253. const auto &main = toposet[group[0]];
  254. for (size_t i = 1; i < group.size(); i++) {
  255. change = ReplaceLoadUser(manager, fg, toposet[group[i]]);
  256. manager->Replace(toposet[group[i]], main);
  257. }
  258. return change;
  259. }
  260. AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
  261. auto &params = fg->parameters();
  262. auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
  263. auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
  264. if (iter != end) {
  265. return *iter;
  266. }
  267. auto monad = NewValueNode(kUMonad);
  268. monad->set_abstract(kUMonad->ToAbstract());
  269. return monad;
  270. }
  271. // Replace UpdateStates with U for first load.
  272. // Covert:
  273. // u1 = UpdateState(u, c)
  274. // p1 = Load(para1, u1) // first load for para1
  275. // To:
  276. // u1 = UpdateState(u, c)
  277. // p1 = Load(para1, u') // u' is first monad in graph or new monad
  278. bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
  279. if (need_replace_loads.size() == 0) {
  280. return false;
  281. }
  282. bool change = false;
  283. constexpr size_t second_input_index = 2;
  284. auto monad = GetFirstMonad(fg);
  285. for (const auto &load_node : need_replace_loads) {
  286. if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
  287. continue;
  288. }
  289. auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
  290. auto mgr = fg->manager();
  291. MS_EXCEPTION_IF_NULL(mgr);
  292. // If the u1 only used by Load and one other updatestate, no need to replace u1 by u'.
  293. auto &node_users = mgr->node_users()[update_state];
  294. constexpr size_t kUserSize = 2;
  295. if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState) || node_users.size() == kUserSize) {
  296. continue;
  297. }
  298. mgr->SetEdge(load_node, second_input_index, monad);
  299. change = true;
  300. }
  301. return change;
  302. }
  303. } // namespace
  304. // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
  305. // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
  306. bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
  307. auto changed = false;
  308. for (const FuncGraphPtr &fg : manager->func_graphs()) {
  309. std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
  310. // Record the set of the first load of param which no nodes modify param before the load in toposort.
  311. std::vector<AnfNodePtr> need_replace_loads;
  312. // Record the param and the toposort id of the unload user of param, they may modify the value of param.
  313. ParamUserMap param_users;
  314. // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
  315. std::vector<size_t> special_op_indexes;
  316. auto load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads, &param_users, &special_op_indexes);
  317. // Split group if there is no-load node between two load nodes.
  318. std::vector<std::vector<size_t>> need_merge_loads;
  319. for (auto &load_group : load_groups) {
  320. auto &ref_key = load_group.first;
  321. auto &group = load_group.second;
  322. const auto &param_user_indexes = param_users[ref_key];
  323. auto groups = SplitGroup(group, param_user_indexes, special_op_indexes);
  324. need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
  325. }
  326. for (auto &group : need_merge_loads) {
  327. bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
  328. if (replaced) {
  329. changed = true;
  330. }
  331. }
  332. bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
  333. if (update_state_replaced) {
  334. changed = true;
  335. }
  336. }
  337. return changed;
  338. }
  339. // Eliminate auto monad node:
  340. // From:
  341. // u1 = UpdateState(...);
  342. // xxx = User(u1); // Other users except below Depend.
  343. // output = Depend(output, u1);
  344. // return output;
  345. // To:
  346. // u1 = UpdateState(...);
  347. // xxx = User(u1);
  348. // return output;
  349. bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
  350. auto changed = false;
  351. for (const FuncGraphPtr &fg : manager->func_graphs()) {
  352. auto output = fg->output();
  353. if (output == nullptr) {
  354. continue;
  355. }
  356. if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
  357. continue;
  358. }
  359. constexpr size_t attach_index = 2;
  360. auto attach = output->cast<CNodePtr>()->input(attach_index);
  361. if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
  362. continue;
  363. }
  364. auto &node_users = manager->node_users();
  365. auto iter = node_users.find(attach);
  366. if (iter == node_users.end()) {
  367. MS_LOG(EXCEPTION) << "No user of node: " << attach->DebugString();
  368. }
  369. auto &users = iter->second;
  370. if (users.size() <= 1) {
  371. continue;
  372. }
  373. constexpr size_t input_index = 1;
  374. auto input = output->cast<CNodePtr>()->input(input_index);
  375. MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
  376. fg->set_output(input);
  377. changed = true;
  378. }
  379. MS_LOG(DEBUG) << "Changed: " << changed;
  380. return changed;
  381. }
  382. } // namespace opt
  383. } // namespace mindspore