You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

specialize_test.cc 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "ir/manager.h"
  21. #include "pipeline/jit/static_analysis/prim.h"
  22. #include "pipeline/jit/static_analysis/program_specialize.h"
  23. #include "pipeline/static_analysis/helper.h"
  24. #include "utils/log_adapter.h"
  25. #include "ir/graph_utils.h"
  26. #include "utils/misc.h"
  27. #include "debug/draw.h"
  28. namespace mindspore {
  29. namespace abstract {
  30. class TestSpecializeGraph : public UT::Common {
  31. public:
  32. void SetUp();
  33. void TearDown();
  34. // f(x) call g(x)
  35. FuncGraphPtr graph_f_;
  36. FuncGraphPtr graph_g_;
  37. // alpha(x) return beta(x) closure;
  38. FuncGraphPtr graph_alpha_;
  39. FuncGraphPtr graph_beta_;
  40. std::shared_ptr<AnalysisEngine> engine_;
  41. std::shared_ptr<ProgramSpecializer> special_;
  42. };
  43. void TestSpecializeGraph::SetUp() {
  44. UT::InitPythonPath();
  45. // init resource
  46. engine_ = SetupAnalysisEngine();
  47. special_ = std::make_shared<ProgramSpecializer>(engine_);
  48. /*
  49. * def g(y):
  50. * return y;
  51. */
  52. graph_g_ = std::make_shared<FuncGraph>();
  53. ParameterPtr y = graph_g_->add_parameter();
  54. auto prim_return = std::make_shared<Primitive>("return");
  55. std::vector<AnfNodePtr> inputs;
  56. inputs.push_back(NewValueNode(prim_return));
  57. inputs.push_back(y);
  58. CNodePtr cnode_g_ret = graph_g_->NewCNode(inputs);
  59. graph_g_->set_return(cnode_g_ret);
  60. /*
  61. * def f(x):
  62. * return g(x)
  63. */
  64. graph_f_ = std::make_shared<FuncGraph>();
  65. ParameterPtr x = graph_f_->add_parameter();
  66. inputs.clear();
  67. inputs.push_back(NewValueNode(graph_g_));
  68. inputs.push_back(x);
  69. CNodePtr cnode_f = graph_f_->NewCNode(inputs);
  70. inputs.clear();
  71. inputs.push_back(NewValueNode(prim_return));
  72. inputs.push_back(cnode_f);
  73. CNodePtr cnode_f_ret = graph_f_->NewCNode(inputs);
  74. graph_f_->set_return(cnode_f_ret);
  75. /* build a closure func_graph */
  76. /*
  77. *def alpha(x, y):
  78. * def beta(x1):
  79. * return x1 + y
  80. * return beta(x)
  81. */
  82. graph_alpha_ = std::make_shared<FuncGraph>();
  83. graph_beta_ = std::make_shared<FuncGraph>();
  84. x = graph_alpha_->add_parameter();
  85. y = graph_alpha_->add_parameter();
  86. // build func_graph beta
  87. ParameterPtr x1 = graph_beta_->add_parameter();
  88. inputs.clear();
  89. inputs.push_back(NewValueNode(std::make_shared<Primitive>("scalar_add")));
  90. inputs.push_back(x1);
  91. inputs.push_back(y);
  92. CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
  93. inputs.clear();
  94. inputs.push_back(NewValueNode(std::make_shared<Primitive>("return")));
  95. inputs.push_back(cnode_add);
  96. CNodePtr cnode_return = graph_beta_->NewCNode(inputs);
  97. graph_beta_->set_return(cnode_return);
  98. // build func_graph alpha
  99. inputs.clear();
  100. inputs.push_back(NewValueNode(graph_beta_));
  101. inputs.push_back(x);
  102. CNodePtr cnode_graph_beta_ = graph_alpha_->NewCNode(inputs);
  103. inputs.clear();
  104. inputs.push_back(NewValueNode(prim_return));
  105. inputs.push_back(cnode_graph_beta_);
  106. cnode_return = graph_alpha_->NewCNode(inputs);
  107. graph_alpha_->set_return(cnode_return);
  108. }
  109. void TestSpecializeGraph::TearDown() {}
  110. TEST_F(TestSpecializeGraph, test_specialize) {
  111. AbstractBasePtrList args_spec_list;
  112. MS_LOG(INFO) << "Begin TestSpecializeGraph call other graph.";
  113. MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
  114. AbstractBasePtr abstract_v1 = FromValue(1, false);
  115. args_spec_list.push_back(abstract_v1);
  116. AnalysisResult result = engine_->Run(graph_f_, args_spec_list);
  117. FuncGraphPtr new_graph = special_->Run(graph_f_, result.context);
  118. }
  119. TEST_F(TestSpecializeGraph, test_specialize1) {
  120. AbstractBasePtrList args_spec_list;
  121. AbstractBasePtr abstract_v1 = FromValue(1, true);
  122. AbstractBasePtr abstract_v2 = FromValue(2, true);
  123. args_spec_list.push_back(abstract_v1);
  124. args_spec_list.push_back(abstract_v2);
  125. AnalysisResult result = engine_->Run(graph_alpha_, args_spec_list);
  126. draw::Draw("befor_graph_alpha.dot", graph_alpha_);
  127. FuncGraphPtr new_graph = special_->Run(graph_alpha_, result.context);
  128. if (new_graph) {
  129. draw::Draw("after_graph_alpha.dot", new_graph);
  130. }
  131. }
  132. class TestSpecializeMetaFuncGraph : public UT::Common {
  133. public:
  134. void SetUp();
  135. void TearDown();
  136. FuncGraphPtr graph_;
  137. std::shared_ptr<AnalysisEngine> engine_;
  138. std::shared_ptr<ProgramSpecializer> special_;
  139. };
  140. class MetaScalarAdd : public MetaFuncGraph {
  141. public:
  142. explicit MetaScalarAdd(std::string name) : MetaFuncGraph(name) {}
  143. ~MetaScalarAdd() {}
  144. /*
  145. * Generate a Graph for the given abstract arguments.
  146. */
  147. FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override {
  148. FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
  149. ParameterPtr x = graph_g->add_parameter();
  150. ParameterPtr y = graph_g->add_parameter();
  151. auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  152. std::vector<AnfNodePtr> inputs;
  153. inputs.push_back(NewValueNode(prim_scalar_add));
  154. inputs.push_back(x);
  155. inputs.push_back(y);
  156. CNodePtr cnode_add = graph_g->NewCNode(inputs);
  157. auto prim_return = std::make_shared<Primitive>("return");
  158. inputs.clear();
  159. inputs.push_back(NewValueNode(prim_return));
  160. inputs.push_back(cnode_add);
  161. CNodePtr cnode_return = graph_g->NewCNode(inputs);
  162. graph_g->set_return(cnode_return);
  163. return graph_g;
  164. }
  165. };
  166. void TestSpecializeMetaFuncGraph::SetUp() {
  167. UT::InitPythonPath();
  168. // init resource
  169. engine_ = SetupAnalysisEngine();
  170. special_ = std::make_shared<ProgramSpecializer>(engine_);
  171. /*
  172. * def f(x, y):
  173. * return mata_scalar_add(x, y)
  174. */
  175. graph_ = std::make_shared<FuncGraph>();
  176. ParameterPtr x = graph_->add_parameter();
  177. ParameterPtr y = graph_->add_parameter();
  178. std::shared_ptr<MetaFuncGraph> meta_scalar_add = std::make_shared<MetaScalarAdd>("meta_scalar_add");
  179. std::vector<AnfNodePtr> inputs;
  180. inputs.push_back(NewValueNode(meta_scalar_add));
  181. inputs.push_back(x);
  182. inputs.push_back(y);
  183. CNodePtr cnode_add = graph_->NewCNode(inputs);
  184. auto prim_return = std::make_shared<Primitive>("return");
  185. inputs.clear();
  186. inputs.push_back(NewValueNode(prim_return));
  187. inputs.push_back(cnode_add);
  188. CNodePtr cnode_return = graph_->NewCNode(inputs);
  189. graph_->set_return(cnode_return);
  190. }
  191. void TestSpecializeMetaFuncGraph::TearDown() {}
  192. TEST_F(TestSpecializeMetaFuncGraph, test_specialize) {
  193. AbstractBasePtrList args_spec_list;
  194. std::cout << graph_->get_return()->ToString() << std::endl;
  195. AbstractBasePtr abstract_v1 = FromValue(1, true);
  196. AbstractBasePtr abstract_v2 = FromValue(2, true);
  197. args_spec_list.push_back(abstract_v1);
  198. args_spec_list.push_back(abstract_v2);
  199. AnalysisResult result = engine_->Run(graph_, args_spec_list);
  200. draw::Draw("befor_graph.dot", graph_);
  201. FuncGraphPtr new_graph = special_->Run(graph_, result.context);
  202. if (new_graph) {
  203. draw::Draw("after_graph.dot", new_graph);
  204. }
  205. }
  206. } // namespace abstract
  207. } // namespace mindspore