You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

specialize_test.cc 7.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "ir/manager.h"
  21. #include "pipeline/jit/static_analysis/prim.h"
  22. #include "pipeline/jit/static_analysis/program_specialize.h"
  23. #include "pipeline/static_analysis/helper.h"
  24. #include "utils/log_adapter.h"
  25. #include "ir/graph_utils.h"
  26. #include "utils/misc.h"
  27. #include "debug/draw.h"
  28. #include "base/core_ops.h"
  29. namespace mindspore {
  30. namespace abstract {
  31. class TestSpecializeGraph : public UT::Common {
  32. public:
  33. void SetUp();
  34. void TearDown();
  35. // f(x) call g(x)
  36. FuncGraphPtr graph_f_;
  37. FuncGraphPtr graph_g_;
  38. // alpha(x) return beta(x) closure;
  39. FuncGraphPtr graph_alpha_;
  40. FuncGraphPtr graph_beta_;
  41. std::shared_ptr<AnalysisEngine> engine_;
  42. std::shared_ptr<ProgramSpecializer> special_;
  43. };
  44. void TestSpecializeGraph::SetUp() {
  45. UT::InitPythonPath();
  46. // init resource
  47. engine_ = SetupAnalysisEngine();
  48. special_ = std::make_shared<ProgramSpecializer>(engine_);
  49. /*
  50. * def g(y):
  51. * return y;
  52. */
  53. graph_g_ = std::make_shared<FuncGraph>();
  54. ParameterPtr y = graph_g_->add_parameter();
  55. auto prim_return = std::make_shared<Primitive>("Return");
  56. std::vector<AnfNodePtr> inputs;
  57. inputs.push_back(NewValueNode(prim_return));
  58. inputs.push_back(y);
  59. CNodePtr cnode_g_ret = graph_g_->NewCNode(inputs);
  60. graph_g_->set_return(cnode_g_ret);
  61. /*
  62. * def f(x):
  63. * return g(x)
  64. */
  65. graph_f_ = std::make_shared<FuncGraph>();
  66. ParameterPtr x = graph_f_->add_parameter();
  67. inputs.clear();
  68. inputs.push_back(NewValueNode(graph_g_));
  69. inputs.push_back(x);
  70. CNodePtr cnode_f = graph_f_->NewCNode(inputs);
  71. inputs.clear();
  72. inputs.push_back(NewValueNode(prim_return));
  73. inputs.push_back(cnode_f);
  74. CNodePtr cnode_f_ret = graph_f_->NewCNode(inputs);
  75. graph_f_->set_return(cnode_f_ret);
  76. /* build a closure func_graph */
  77. /*
  78. *def alpha(x, y):
  79. * def beta(x1):
  80. * return x1 + y
  81. * return beta(x)
  82. */
  83. graph_alpha_ = std::make_shared<FuncGraph>();
  84. graph_beta_ = std::make_shared<FuncGraph>();
  85. x = graph_alpha_->add_parameter();
  86. y = graph_alpha_->add_parameter();
  87. // build func_graph beta
  88. ParameterPtr x1 = graph_beta_->add_parameter();
  89. inputs.clear();
  90. inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kScalarAdd)));
  91. inputs.push_back(x1);
  92. inputs.push_back(y);
  93. CNodePtr cnode_add = graph_beta_->NewCNode(inputs);
  94. inputs.clear();
  95. inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
  96. inputs.push_back(cnode_add);
  97. CNodePtr cnode_return = graph_beta_->NewCNode(inputs);
  98. graph_beta_->set_return(cnode_return);
  99. // build func_graph alpha
  100. inputs.clear();
  101. inputs.push_back(NewValueNode(graph_beta_));
  102. inputs.push_back(x);
  103. CNodePtr cnode_graph_beta_ = graph_alpha_->NewCNode(inputs);
  104. inputs.clear();
  105. inputs.push_back(NewValueNode(prim_return));
  106. inputs.push_back(cnode_graph_beta_);
  107. cnode_return = graph_alpha_->NewCNode(inputs);
  108. graph_alpha_->set_return(cnode_return);
  109. }
  110. void TestSpecializeGraph::TearDown() {}
  111. TEST_F(TestSpecializeGraph, test_specialize) {
  112. AbstractBasePtrList args_spec_list;
  113. MS_LOG(INFO) << "Begin TestSpecializeGraph call other graph.";
  114. MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
  115. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
  116. args_spec_list.push_back(abstract_v1);
  117. AnalysisResult result = engine_->Run(graph_f_, args_spec_list);
  118. FuncGraphPtr new_graph = special_->Run(graph_f_, result.context);
  119. }
  120. TEST_F(TestSpecializeGraph, test_specialize1) {
  121. AbstractBasePtrList args_spec_list;
  122. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), true);
  123. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), true);
  124. args_spec_list.push_back(abstract_v1);
  125. args_spec_list.push_back(abstract_v2);
  126. AnalysisResult result = engine_->Run(graph_alpha_, args_spec_list);
  127. draw::Draw("befor_graph_alpha.dot", graph_alpha_);
  128. FuncGraphPtr new_graph = special_->Run(graph_alpha_, result.context);
  129. if (new_graph) {
  130. draw::Draw("after_graph_alpha.dot", new_graph);
  131. }
  132. }
  133. class TestSpecializeMetaFuncGraph : public UT::Common {
  134. public:
  135. void SetUp();
  136. void TearDown();
  137. FuncGraphPtr graph_;
  138. std::shared_ptr<AnalysisEngine> engine_;
  139. std::shared_ptr<ProgramSpecializer> special_;
  140. };
  141. class MetaScalarAdd : public MetaFuncGraph {
  142. public:
  143. explicit MetaScalarAdd(std::string name) : MetaFuncGraph(name) {}
  144. ~MetaScalarAdd() {}
  145. /*
  146. * Generate a Graph for the given abstract arguments.
  147. */
  148. FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override {
  149. FuncGraphPtr graph_g = std::make_shared<FuncGraph>();
  150. ParameterPtr x = graph_g->add_parameter();
  151. ParameterPtr y = graph_g->add_parameter();
  152. auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
  153. std::vector<AnfNodePtr> inputs;
  154. inputs.push_back(NewValueNode(prim_scalar_add));
  155. inputs.push_back(x);
  156. inputs.push_back(y);
  157. CNodePtr cnode_add = graph_g->NewCNode(inputs);
  158. auto prim_return = std::make_shared<Primitive>("Return");
  159. inputs.clear();
  160. inputs.push_back(NewValueNode(prim_return));
  161. inputs.push_back(cnode_add);
  162. CNodePtr cnode_return = graph_g->NewCNode(inputs);
  163. graph_g->set_return(cnode_return);
  164. return graph_g;
  165. }
  166. };
  167. void TestSpecializeMetaFuncGraph::SetUp() {
  168. UT::InitPythonPath();
  169. // init resource
  170. engine_ = SetupAnalysisEngine();
  171. special_ = std::make_shared<ProgramSpecializer>(engine_);
  172. /*
  173. * def f(x, y):
  174. * return mata_scalar_add(x, y)
  175. */
  176. graph_ = std::make_shared<FuncGraph>();
  177. ParameterPtr x = graph_->add_parameter();
  178. ParameterPtr y = graph_->add_parameter();
  179. std::shared_ptr<MetaFuncGraph> meta_scalar_add = std::make_shared<MetaScalarAdd>("meta_scalar_add");
  180. std::vector<AnfNodePtr> inputs;
  181. inputs.push_back(NewValueNode(meta_scalar_add));
  182. inputs.push_back(x);
  183. inputs.push_back(y);
  184. CNodePtr cnode_add = graph_->NewCNode(inputs);
  185. auto prim_return = std::make_shared<Primitive>("Return");
  186. inputs.clear();
  187. inputs.push_back(NewValueNode(prim_return));
  188. inputs.push_back(cnode_add);
  189. CNodePtr cnode_return = graph_->NewCNode(inputs);
  190. graph_->set_return(cnode_return);
  191. }
  192. void TestSpecializeMetaFuncGraph::TearDown() {}
  193. TEST_F(TestSpecializeMetaFuncGraph, test_specialize) {
  194. AbstractBasePtrList args_spec_list;
  195. std::cout << graph_->get_return()->ToString() << std::endl;
  196. AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), true);
  197. AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(2), true);
  198. args_spec_list.push_back(abstract_v1);
  199. args_spec_list.push_back(abstract_v2);
  200. AnalysisResult result = engine_->Run(graph_, args_spec_list);
  201. draw::Draw("befor_graph.dot", graph_);
  202. FuncGraphPtr new_graph = special_->Run(graph_, result.context);
  203. if (new_graph) {
  204. draw::Draw("after_graph.dot", new_graph);
  205. }
  206. }
  207. } // namespace abstract
  208. } // namespace mindspore