You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

env_item_eliminate.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <algorithm>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "utils/symbolic.h"
  30. namespace mindspore {
  31. namespace opt {
  32. namespace irpass {
  33. namespace internal {
  34. class EnvGetitemTransform {
  35. public:
  36. EnvGetitemTransform() : cache_() {}
  37. ~EnvGetitemTransform() = default;
  38. FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
  39. if (cache_.find(fg) == cache_.end()) {
  40. cache_[fg] = {};
  41. }
  42. auto &cache = cache_[fg];
  43. auto hash_key = std::make_pair(key, default_node);
  44. if (cache.find(hash_key) == cache.end()) {
  45. std::ostringstream ss("env", std::ostringstream::app);
  46. if (key->node() != nullptr) {
  47. ss << key->node()->ToString();
  48. }
  49. auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
  50. auto env = new_fg->output();
  51. while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
  52. // {prim::kPrimEnvSetItem, env, symbolickey, value}
  53. auto &inputs = env->cast<CNodePtr>()->inputs();
  54. if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
  55. MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
  56. }
  57. env = inputs[1];
  58. auto value = inputs[3];
  59. auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
  60. if (*key2 == *key) {
  61. new_fg->set_output(value);
  62. cache[hash_key] = new_fg;
  63. cache_[fg] = cache;
  64. return new_fg;
  65. }
  66. }
  67. new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
  68. cache[hash_key] = new_fg;
  69. }
  70. return cache[hash_key];
  71. }
  72. private:
  73. std::unordered_map<FuncGraphPtr,
  74. std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
  75. cache_;
  76. };
  77. } // namespace internal
  78. // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
  79. class NewEnvGetItem : public AnfVisitor {
  80. public:
  81. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  82. Reset();
  83. auto gety = [this](const AnfNodePtr &node) -> bool {
  84. this->y_ = node;
  85. return true;
  86. };
  87. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
  88. if (env_ != nullptr && env_->Len() == 0) {
  89. return y_;
  90. }
  91. return nullptr;
  92. }
  93. void Visit(const ValueNodePtr &vnode) override {
  94. if (env_ == nullptr) {
  95. env_ = GetValueNode<EnvInstancePtr>(vnode);
  96. }
  97. }
  98. void Reset() {
  99. y_ = nullptr;
  100. env_ = nullptr;
  101. }
  102. private:
  103. AnfNodePtr y_{nullptr};
  104. EnvInstancePtr env_{nullptr};
  105. };
  106. // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
  107. // {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
  108. class AddEnvGetItem : public AnfVisitor {
  109. public:
  110. AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
  111. ~AddEnvGetItem() override = default;
  112. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  113. is_match_ = false;
  114. auto IsAddCNode = [](const AnfNodePtr &node) -> bool {
  115. return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast<CNodePtr>()->size() == 3;
  116. };
  117. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node);
  118. if (!is_match_ || node->func_graph() == nullptr) {
  119. return nullptr;
  120. }
  121. // {prim::kPrimEnvGetItem, {...}, C, Z}
  122. auto cnode = node->cast<CNodePtr>();
  123. auto inp1 = cnode->input(1)->cast<CNodePtr>();
  124. auto c = cnode->input(2);
  125. auto z = cnode->input(3);
  126. // {prim::kPrimEnvAdd, X, Y}
  127. auto x = inp1->input(1);
  128. auto y = inp1->input(2);
  129. auto fg = node->func_graph();
  130. auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z});
  131. auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z});
  132. return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz});
  133. }
  134. void Visit(const AnfNodePtr &) override { is_match_ = true; }
  135. private:
  136. bool is_match_{false};
  137. ValuePtr PrimHyperAdd_;
  138. };
  139. // {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
  140. class EnvGetSetItem : public AnfVisitor {
  141. public:
  142. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  143. is_match_ = false;
  144. auto IsSetCNode = [](const AnfNodePtr &node) -> bool {
  145. if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) {
  146. return false;
  147. }
  148. // {prim::kPrimEnvSetItem, X, C1, Y}
  149. auto &inputs = node->cast<CNodePtr>()->inputs();
  150. if (inputs.size() != 4) {
  151. return false;
  152. }
  153. return IsValueNode<SymbolicKeyInstance>(inputs[2]);
  154. };
  155. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
  156. if (!is_match_ || node->func_graph() == nullptr) {
  157. return nullptr;
  158. }
  159. // {prim::kPrimEnvGetItem, {...}, C2, Z}
  160. auto cnode = node->cast<CNodePtr>();
  161. auto inp1 = cnode->input(1)->cast<CNodePtr>();
  162. auto key2 = cnode->input(2);
  163. auto c2 = GetValueNode<SymbolicKeyInstancePtr>(key2);
  164. auto default_v = cnode->input(3);
  165. // {prim::kPrimEnvSetItem, X, C1, Y}
  166. auto env = inp1->input(1);
  167. auto c1 = GetValueNode<SymbolicKeyInstancePtr>(inp1->input(2));
  168. auto last_set = inp1->input(3);
  169. if (*c1 == *c2) {
  170. return last_set;
  171. }
  172. while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
  173. // {prim::kPrimEnvSetItem, env, symbolickey, value}
  174. auto &inputs = env->cast<CNodePtr>()->inputs();
  175. if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
  176. MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
  177. }
  178. env = inputs[1];
  179. last_set = inputs[3];
  180. auto symbolic_c1 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
  181. if (*symbolic_c1 == *c2) {
  182. return last_set;
  183. }
  184. }
  185. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v});
  186. }
  187. void Visit(const AnfNodePtr &) override { is_match_ = true; }
  188. private:
  189. bool is_match_{false};
  190. };
  191. // {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
  192. class IncorporateEnvGetitem : public AnfVisitor {
  193. public:
  194. IncorporateEnvGetitem() : env_get_item_transform_() {}
  195. ~IncorporateEnvGetitem() override = default;
  196. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  197. is_match_ = false;
  198. auto IsGCNode = [](const AnfNodePtr &node) -> bool {
  199. auto cnode = node->cast<CNodePtr>();
  200. if (cnode == nullptr || cnode->size() < 1) {
  201. return false;
  202. }
  203. return IsValueNode<FuncGraph>(cnode->input(0));
  204. };
  205. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
  206. if (!is_match_) {
  207. return nullptr;
  208. }
  209. // {prim::kPrimEnvGetItem, {...}, C, Y}
  210. auto cnode = node->cast<CNodePtr>();
  211. auto inp1 = cnode->input(1)->cast<CNodePtr>();
  212. auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
  213. auto default_v = cnode->input(3);
  214. // {G, Xs}
  215. auto inputs = inp1->inputs();
  216. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  217. auto new_fg = env_get_item_transform_(fg, key, default_v);
  218. std::vector<AnfNodePtr> args;
  219. args.push_back(NewValueNode(new_fg));
  220. (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
  221. return node->func_graph()->NewCNode(args);
  222. }
  223. void Visit(const AnfNodePtr &) override { is_match_ = true; }
  224. private:
  225. bool is_match_{false};
  226. internal::EnvGetitemTransform env_get_item_transform_;
  227. };
  228. // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
  229. class IncorporateEnvGetitemSwitch : public AnfVisitor {
  230. public:
  231. IncorporateEnvGetitemSwitch() : env_get_item_transform_() {}
  232. ~IncorporateEnvGetitemSwitch() override = default;
  233. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  234. is_match_ = false;
  235. auto IsSwNode = [](const AnfNodePtr &node) -> bool {
  236. auto cnode = node->cast<CNodePtr>();
  237. if (cnode == nullptr || cnode->size() < 1) {
  238. return false;
  239. }
  240. return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch);
  241. };
  242. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
  243. if (!is_match_ || node->func_graph() == nullptr) {
  244. return nullptr;
  245. }
  246. // {prim::kPrimEnvGetItem, {...}, C, Y}
  247. auto cnode = node->cast<CNodePtr>();
  248. auto inp1 = cnode->input(1)->cast<CNodePtr>();
  249. auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
  250. auto default_v = cnode->input(3);
  251. // {{prim::kPrimSwitch, X, G1, G2}, Xs}
  252. auto inputs = inp1->inputs();
  253. is_match_ = false;
  254. AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(inputs[0]);
  255. if (!is_match_) {
  256. return nullptr;
  257. }
  258. // {prim::kPrimSwitch, X, G1, G2}
  259. auto sw = inputs[0]->cast<CNodePtr>();
  260. auto x = sw->input(1);
  261. auto g1 = GetValueNode<FuncGraphPtr>(sw->input(2));
  262. auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
  263. auto new_g1 = env_get_item_transform_(g1, key, default_v);
  264. auto new_g2 = env_get_item_transform_(g2, key, default_v);
  265. auto fg = node->func_graph();
  266. auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});
  267. std::vector<AnfNodePtr> args{new_sw};
  268. (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
  269. return fg->NewCNode(args);
  270. }
  271. void Visit(const AnfNodePtr &) override { is_match_ = true; }
  272. private:
  273. bool is_match_{false};
  274. internal::EnvGetitemTransform env_get_item_transform_;
  275. };
  276. } // namespace irpass
  277. } // namespace opt
  278. } // namespace mindspore
  279. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_