You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

merge_addn.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "frontend/operator/ops.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
  25. namespace mindspore {
  26. namespace opt {
  27. namespace irpass {
// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
//
// Flattens a nested AddN when the inner AddN is the first or the last element of
// the outer MakeTuple and has no user other than the outer tuple.
class MergeAddN : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    Reset();
    // The manager is needed only by is_unique() during the Visit below.
    mng_ = optimizer->resource()->manager();
    is_outer_ = true;
    // Matches AddN(cnode) and dispatches Visit on the tuple argument, which
    // fills Xs_/Ys_/args_ and sets is_match_ when the nested pattern is found.
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
    // do not hold this manager
    mng_ = nullptr;
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto cnode = node->cast<CNodePtr>();
    // Reuse the original AddN's value node so any attributes it carries survive.
    auto addn = NewValueNode(GetValueNode(cnode->input(0)));
    // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
    (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto fg = node->func_graph();
    auto make_node = fg->NewCNode(args_);
    return fg->NewCNode({addn, make_node});
  }

  // Called first for the outer MakeTuple and then, via the re-entrant Match calls
  // below, for the inner one; is_outer_/is_inner_ track which level this is.
  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    auto &inputs = cnode->inputs();

    if (is_outer_) {
      // Tentatively collect every outer element as Ys; the matched inner AddN
      // is removed from Ys_ again further down.
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));

      is_outer_ = false;
      is_inner_ = true;

      // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]);
      if (is_match_) {
        // Merging would orphan the inner AddN; bail out if it has other users.
        if (!is_unique(inputs[1])) {
          is_match_ = false;
          return;
        }
        // Drop the inner AddN itself from Ys, then emit Xs followed by Ys.
        (void)Ys_.erase(Ys_.begin());
        (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
        (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
        return;
      }

      // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back());
      if (is_match_) {
        if (!is_unique(inputs.back())) {
          is_match_ = false;
          return;
        }
        // Drop the inner AddN from the tail of Ys, then emit Ys followed by Xs.
        Ys_.pop_back();
        (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
        (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
        return;
      }

      return;
    }

    if (is_inner_) {
      // Inner MakeTuple reached: its elements are the Xs and the pattern matched.
      is_match_ = true;
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
    }
  }

  // True when `node` has exactly one user in the graph (per the manager's
  // node-user index); false if the node is unknown to the manager.
  bool is_unique(const AnfNodePtr &node) {
    auto &node_users = mng_->node_users();
    if (node_users.find(node) == node_users.end()) {
      return false;
    }

    size_t n_use = node_users[node].size();
    return n_use == 1;
  }

  // Clears all per-invocation state; called at the top of operator().
  void Reset() {
    Xs_.clear();
    Ys_.clear();
    args_.clear();
    is_inner_ = false;
    is_outer_ = false;
    is_match_ = false;
  }

 private:
  FuncGraphManagerPtr mng_{nullptr};
  std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
  bool is_inner_{false}, is_outer_{false}, is_match_{false};
};
  113. // {PrimAddN, {kPrimMakeTuple, Xs}}
  114. class AddNZeroFilter : public AnfVisitor {
  115. public:
  116. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  117. Reset();
  118. AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
  119. if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
  120. return nullptr;
  121. }
  122. // if only two node in filtered_nodes, {make_tuple, x}. return x.
  123. if (filtered_Xs_.size() == 2) {
  124. return filtered_Xs_[1];
  125. }
  126. // if only one node in filtered_nodes, all node is zerolike, return one of the input.
  127. if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
  128. return Xs_[0];
  129. }
  130. if (!has_zero_like_) {
  131. return nullptr;
  132. }
  133. auto cnode = node->cast<CNodePtr>();
  134. auto addn = NewValueNode(GetValueNode(cnode->input(0)));
  135. auto fg = node->func_graph();
  136. auto make_tuple = fg->NewCNode(filtered_Xs_);
  137. return fg->NewCNode({addn, make_tuple});
  138. }
  139. void Visit(const CNodePtr &cnode) override {
  140. if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
  141. return;
  142. }
  143. auto &inputs = cnode->inputs();
  144. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
  145. // {kPrimMakeTuple, X1, X2, ...}
  146. filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
  147. for (auto &x : Xs_) {
  148. if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
  149. filtered_Xs_.push_back(x);
  150. } else {
  151. has_zero_like_ = true;
  152. }
  153. }
  154. }
  155. void Reset() {
  156. Xs_.clear();
  157. filtered_Xs_.clear();
  158. has_zero_like_ = false;
  159. }
  160. private:
  161. std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
  162. bool has_zero_like_{false};
  163. };
  164. // {PrimAddN, {kPrimMakeTuple, Xs}}
  165. // Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
  166. // case0: AddN(inputs)(inputs size < 2) -> error
  167. // case1: AddN(inputs)(all inputs is ValueNode) -> error
  168. // case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
  169. // case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
  170. // -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
  171. class AddNEliminater : public AnfVisitor {
  172. public:
  173. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  174. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  175. return nullptr;
  176. }
  177. auto &inputs = node->cast<CNodePtr>()->inputs();
  178. auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
  179. MS_EXCEPTION_IF_NULL(fg);
  180. auto mng = fg->manager();
  181. MS_EXCEPTION_IF_NULL(mng);
  182. if (fg->recursive()) {
  183. return nullptr;
  184. }
  185. auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
  186. mng->AddFuncGraph(new_fg);
  187. need_update_ = false;
  188. bool changed;
  189. do {
  190. changed = Process(new_fg);
  191. } while (changed);
  192. if (!need_update_) {
  193. return nullptr;
  194. } else {
  195. auto new_sx = inputs;
  196. new_sx[0] = NewValueNode(new_fg);
  197. return node->func_graph()->NewCNode(new_sx);
  198. }
  199. }
  200. bool Process(const FuncGraphPtr &func_graph) {
  201. auto mng = func_graph->manager();
  202. MS_EXCEPTION_IF_NULL(mng);
  203. auto nodes = TopoSort(func_graph->output());
  204. bool changed = false;
  205. for (size_t i = 0; i < nodes.size(); ++i) {
  206. auto node = nodes[i];
  207. if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
  208. continue;
  209. }
  210. auto cnode = node->cast<CNodePtr>();
  211. MS_EXCEPTION_IF_NULL(cnode);
  212. auto &tuple_input = cnode->input(1);
  213. MS_EXCEPTION_IF_NULL(tuple_input);
  214. auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
  215. MS_EXCEPTION_IF_NULL(tuple_input_cnode);
  216. auto &tuple_inputs = tuple_input_cnode->inputs();
  217. if (tuple_inputs.size() < 3) {
  218. // case0: inputs size < 2, error
  219. MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
  220. }
  221. int valuenode_num =
  222. std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
  223. if (IsValueNode<tensor::Tensor>(node)) {
  224. return accumulator + 1;
  225. } else {
  226. return accumulator;
  227. }
  228. });
  229. if (IntToSize(valuenode_num) == tuple_inputs.size()) {
  230. // case1: all inputs is ValueNode, error
  231. MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
  232. }
  233. if (tuple_inputs.size() == 3) {
  234. // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
  235. MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
  236. ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
  237. std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
  238. tuple_inputs[2]};
  239. mng->Replace(node, func_graph->NewCNode(new_xs));
  240. changed = true;
  241. continue;
  242. }
  243. auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
  244. [](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
  245. if (first_valuenode == tuple_inputs.end()) {
  246. // no ValueNode input found.
  247. continue;
  248. } else {
  249. // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
  250. std::vector<AnfNodePtr> make_tuple_new_xs{
  251. NewValueNode(prim::kPrimMakeTuple),
  252. };
  253. std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
  254. [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
  255. if (node != *first_valuenode) {
  256. make_tuple_new_xs.push_back(node);
  257. }
  258. });
  259. ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
  260. auto new_addn = func_graph->NewCNode(
  261. {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
  262. ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
  263. auto new_add =
  264. func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
  265. (void)mng->Replace(node, new_add);
  266. changed = true;
  267. continue;
  268. }
  269. }
  270. need_update_ = need_update_ || changed;
  271. return changed;
  272. }
  273. private:
  274. bool need_update_{false};
  275. };
  276. } // namespace irpass
  277. } // namespace opt
  278. } // namespace mindspore
  279. #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_