You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opt_test.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include <iostream>
#include <memory>
#include <numeric>

#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
#include "ir/visitor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "pipeline/jit/action.h"
#include "include/common/debug/draw.h"
#include "frontend/operator/ops.h"
#include "include/common/utils/cse.h"
#include "include/common/utils/convert_utils.h"
  33. namespace mindspore {
  34. namespace opt {
// Fixture for frontend optimizer substitution tests. Test graphs are fetched
// from the python definitions in gtest_input.optimizer.opt_test.
class TestOptOpt : public UT::Common {
 public:
  TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {}

  // Toy rewrite rule: P(P(x)) -> P(x). Matches a CNode "P(arg)" whose single
  // argument is itself a CNode, and re-emits P applied to the inner argument.
  class IdempotentEliminater : public AnfVisitor {
   public:
    AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
      x_ = nullptr;
      // Match "P(cnode)"; on a hit, Match dispatches to Visit() below, which
      // fills x_ with the inner P-call's argument.
      AnfVisitor::Match(P, {irpass::IsCNode})(node);
      if (x_ == nullptr || node->func_graph() == nullptr) {
        return nullptr;
      }
      return node->func_graph()->NewCNode({NewValueNode(P), x_});
    };
    void Visit(const CNodePtr &cnode) override {
      // Capture only when the inner CNode is itself "P(x)" (primitive + exactly one input).
      if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) {
        x_ = cnode->input(1);
      }
    }

   private:
    AnfNodePtr x_{nullptr};  // argument of the inner P-call, set by Visit()
  };

  // Toy rewrite rule: Q(<constant>) -> P(<constant>).
  class QctToP : public AnfVisitor {
   public:
    AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
      v_ = nullptr;
      // Match "Q(vnode)"; on a hit, Visit() records the matched value node.
      AnfVisitor::Match(Q, {irpass::IsVNode})(node);
      if (v_ == nullptr || node->func_graph() == nullptr) {
        return nullptr;
      }
      return node->func_graph()->NewCNode({NewValueNode(P), v_});
    };
    void Visit(const ValueNodePtr &vnode) override { v_ = vnode; }

   private:
    AnfNodePtr v_{nullptr};  // matched constant operand, set by Visit()
  };

  // Builds the substitutions exercised by the test cases below.
  void SetUp() {
    elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
    elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
    idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
    Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
  }

  // Applies `transform` to a clone of `gbefore` and reports whether the result
  // is isomorphic to `gafter`. Cloning keeps `gbefore` itself untouched, and
  // the equivalence maps are reset on every call.
  bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
    equiv_node.clear();
    equiv_graph.clear();

    FuncGraphPtr gbefore_clone = BasicClone(gbefore);
    OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
    transform(gbefore_clone, optimizer);

    return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
  }

  // Convenience wrapper: run the (possibly empty) substitution list `opts`
  // over `before` and compare the result against `after`.
  bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
    SubstitutionList eq(opts);
    return CheckTransform(before, after, eq);
  }

 public:
  UT::PyFuncGraphFetcher getPyFun;    // fetches parsed graphs from the python test inputs
  FuncGraphPairMapEquiv equiv_graph;  // graph-level map filled by Isomorphic()
  NodeMapEquiv equiv_node;            // node-level map filled by Isomorphic()
  irpass::OptimizeIRPassLib irpass_lib;

  // Dummy primitives used to build the toy rewrite rules above.
  static const PrimitivePtr P;
  static const PrimitivePtr Q;
  static const PrimitivePtr R;

  SubstitutionPtr elim_Z;
  SubstitutionPtr elim_R;
  SubstitutionPtr idempotent_P;
  SubstitutionPtr Qct_to_P;
  SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
};

const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
const PrimitivePtr TestOptOpt::Q = std::make_shared<Primitive>("Q");
const PrimitivePtr TestOptOpt::R = std::make_shared<Primitive>("R");
  105. TEST_F(TestOptOpt, TestCheckOptIsClone) {
  106. FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  107. ASSERT_TRUE(nullptr != before);
  108. ASSERT_TRUE(CheckOpt(before, before));
  109. ASSERT_FALSE(CheckOpt(before, before, std::vector<SubstitutionPtr>({elim_Z})));
  110. }
  111. TEST_F(TestOptOpt, Elim) {
  112. FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  113. FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after");
  114. ASSERT_TRUE(nullptr != before);
  115. ASSERT_TRUE(nullptr != after);
  116. ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z})));
  117. }
  118. TEST_F(TestOptOpt, ElimTwo) {
  119. FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_2");
  120. FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after");
  121. ASSERT_TRUE(nullptr != before);
  122. ASSERT_TRUE(nullptr != after);
  123. ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z})));
  124. }
  125. TEST_F(TestOptOpt, ElimR) {
  126. FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
  127. FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");
  128. ASSERT_TRUE(nullptr != before);
  129. ASSERT_TRUE(nullptr != after);
  130. ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_R})));
  131. }
  132. TEST_F(TestOptOpt, idempotent) {
  133. FuncGraphPtr before_2 = getPyFun.CallAndParseRet("test_idempotent", "before_2");
  134. FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_idempotent", "before_1");
  135. FuncGraphPtr after = getPyFun.CallAndParseRet("test_idempotent", "after");
  136. ASSERT_TRUE(nullptr != before_2);
  137. ASSERT_TRUE(nullptr != before_1);
  138. ASSERT_TRUE(nullptr != after);
  139. ASSERT_TRUE(CheckOpt(before_1, after, std::vector<SubstitutionPtr>({idempotent_P})));
  140. ASSERT_TRUE(CheckOpt(before_2, after, std::vector<SubstitutionPtr>({idempotent_P})));
  141. }
  142. TEST_F(TestOptOpt, ConstantVariable) {
  143. FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_variable", "before_1");
  144. FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_variable", "after");
  145. ASSERT_TRUE(nullptr != before);
  146. ASSERT_TRUE(nullptr != after);
  147. ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({Qct_to_P})));
  148. }
  149. TEST_F(TestOptOpt, CSE) {
  150. // test a simple cse testcase test_f1
  151. FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1");
  152. ASSERT_TRUE(nullptr != test_graph1);
  153. // add func_graph the GraphManager
  154. FuncGraphManagerPtr manager1 = Manage(test_graph1);
  155. ASSERT_EQ(manager1->all_nodes().size(), 9);
  156. auto cse = std::make_shared<CSE>();
  157. ASSERT_TRUE(cse != nullptr);
  158. bool is_changed = cse->Cse(test_graph1, manager1);
  159. ASSERT_TRUE(is_changed);
  160. ASSERT_EQ(manager1->all_nodes().size(), 8);
  161. // test a more complicated case test_f2
  162. FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2");
  163. ASSERT_TRUE(nullptr != test_graph2);
  164. FuncGraphManagerPtr manager2 = Manage(test_graph2);
  165. ASSERT_EQ(manager2->all_nodes().size(), 16);
  166. is_changed = cse->Cse(test_graph2, manager2);
  167. ASSERT_TRUE(is_changed);
  168. ASSERT_EQ(manager2->all_nodes().size(), 12);
  169. }
  170. size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
  171. // Check tuple params and tuple args.
  172. auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  173. size_t tuple_arg_param_num = 0;
  174. auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
  175. auto abs = node->abstract();
  176. MS_EXCEPTION_IF_NULL(abs);
  177. return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
  178. };
  179. for (const auto &node : all_nodes) {
  180. // Count func graph call tuple args.
  181. if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
  182. auto call_node = node->cast<CNodePtr>();
  183. tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
  184. tuple_arg_param_num, tuple_accumulate_func);
  185. }
  186. // Count partial tuple args.
  187. if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  188. auto partial = node->cast<CNodePtr>();
  189. constexpr auto kPartialFirstArgIdx = 2;
  190. tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
  191. tuple_arg_param_num, tuple_accumulate_func);
  192. }
  193. // Count tuple params.
  194. if (IsValueNode<FuncGraph>(node)) {
  195. auto fg = GetValueNode<FuncGraphPtr>(node);
  196. tuple_arg_param_num =
  197. std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func);
  198. }
  199. }
  200. return tuple_arg_param_num;
  201. }
  202. // Feature: Switch call tuple arg transform.
  203. // Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in
  204. // the same time.
  205. // Expectation: All tuple args are correctly transformed to tensor args.
  206. TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
  207. FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
  208. ASSERT_TRUE(nullptr != test_graph);
  209. FuncGraphManagerPtr manager1 = Manage(test_graph);
  210. pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  211. std::vector<AbstractBasePtr> args_spec;
  212. // Renormalize firstly.
  213. auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  214. ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
  215. // Flatten tuple param and args.
  216. OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  217. SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  218. transform(renormalized_fg, optimizer);
  219. // Renormalize again.
  220. auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  221. ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  222. abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  223. abstract::AnalysisContext::ClearContext();
  224. }
  225. // Feature: Switch layer call tuple arg transform.
  226. // Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's
  227. // tensor arg in the same time.
  228. // Expectation: All tuple args are correctly transformed to tensor args.
  229. TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
  230. FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
  231. ASSERT_TRUE(nullptr != test_graph);
  232. FuncGraphManagerPtr manager1 = Manage(test_graph);
  233. pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  234. std::vector<AbstractBasePtr> args_spec;
  235. // Renormalize firstly.
  236. auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  237. ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
  238. // Flatten tuple param and args.
  239. OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  240. SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  241. transform(renormalized_fg, optimizer);
  242. // Renormalize again.
  243. auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  244. ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  245. abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  246. abstract::AnalysisContext::ClearContext();
  247. }
  248. // Feature: Single graph call tuple arg transform.
  249. // Description: Test single graph call's tuple arg transform.This case include tuple in tuple args.
  250. // Expectation: All tuple args are correctly transformed to tensor args.
  251. TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
  252. FuncGraphPtr test_graph =
  253. getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
  254. ASSERT_TRUE(nullptr != test_graph);
  255. FuncGraphManagerPtr manager1 = Manage(test_graph);
  256. pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
  257. std::vector<AbstractBasePtr> args_spec;
  258. // Renormalize firstly.
  259. auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
  260. ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
  261. // Flatten tuple param and args.
  262. OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
  263. SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
  264. transform(renormalized_fg, optimizer);
  265. // Renormalize again.
  266. auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
  267. ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
  268. abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  269. abstract::AnalysisContext::ClearContext();
  270. }
  271. } // namespace opt
  272. } // namespace mindspore