You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluator.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_
  19. #define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_
  20. #include <memory>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <vector>
  24. #include "pipeline/static_analysis/static_analysis.h"
  25. namespace mindspore {
  26. namespace abstract {
  27. using EvaluatorCacheMap =
  28. std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
  29. using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
  30. class Evaluator : public Base {
  31. public:
  32. explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {}
  33. ~Evaluator() override = default;
  34. MS_DECLARE_PARENT(Evaluator, Base);
  35. // difference between Run() and Infer():
  36. // Run() will be called with ConfigPtrList, but Infer() will be called with AbstractBasePtr.
  37. // Run() will modify cache_ member, so it cannot marked as const;
  38. virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
  39. virtual AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
  40. virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
  41. std::string ToString() const override { return identifier_; }
  42. virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
  43. virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
  44. EvaluatorCacheMapPtr &cache() { return cache_; }
  45. EvaluatorCacheMapPtr cache_;
  46. std::string identifier_;
  47. AnfNodeWeakPtr bound_node_;
  48. };
  49. class PrimEvaluator : public Evaluator {
  50. public:
  51. explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
  52. ~PrimEvaluator() override = default;
  53. MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
  54. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) final {
  55. MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
  56. }
  57. };
  58. class TrivialPrimEvaluator : public PrimEvaluator {
  59. public:
  60. explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  61. ~TrivialPrimEvaluator() override = default;
  62. MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
  63. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  64. virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
  65. };
  66. class TransitionPrimEvaluator : public PrimEvaluator {
  67. public:
  68. explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  69. ~TransitionPrimEvaluator() override = default;
  70. MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
  71. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  72. // Parameter in_conf0 : the first element in args_conf_list;
  73. virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
  74. const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
  75. };
  76. class SymbolicPrimEvaluator : public PrimEvaluator {
  77. public:
  78. explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  79. ~SymbolicPrimEvaluator() override = default;
  80. MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
  81. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  82. virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
  83. };
  84. // Evaluator will be stored in AnalysisEngine.constructors_
  85. using EvaluatorPtrList = std::vector<EvaluatorPtr>;
  86. class DummyEvaluator : public Evaluator {
  87. public:
  88. DummyEvaluator() : Evaluator("dummy") {}
  89. ~DummyEvaluator() override = default;
  90. MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
  91. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
  92. };
  93. // Wrap another evaluator to track a subset of uses.
  94. // A TrackedEvaluator has its own cache that maps possible calls to
  95. // their results, but is ultimately backed by a different evaluator.
  96. // Multiple TrackedEvaluators can be backed by the same Evaluator.
  97. class TrackedEvaluator : public Evaluator {
  98. public:
  99. explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {}
  100. ~TrackedEvaluator() override = default;
  101. MS_DECLARE_PARENT(TrackedEvaluator, Evaluator);
  102. AnfNodePtr bound_node() const override {
  103. if (sub_evaluator_ != nullptr) {
  104. return sub_evaluator_->bound_node();
  105. }
  106. return bound_node_.lock();
  107. }
  108. void set_bound_node(const AnfNodePtr &node) override {
  109. if (sub_evaluator_ != nullptr) {
  110. sub_evaluator_->set_bound_node(node);
  111. }
  112. bound_node_ = AnfNodeWeakPtr(node);
  113. }
  114. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  115. MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
  116. }
  117. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  118. AnfNodeConfigPtr out_conf) override;
  119. std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
  120. private:
  121. EvaluatorPtr sub_evaluator_;
  122. };
  123. class BaseFuncGraphEvaluator : public Evaluator {
  124. public:
  125. explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context)
  126. : Evaluator("basegraph"), parent_context_(context) {}
  127. ~BaseFuncGraphEvaluator() override = default;
  128. MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
  129. AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  130. virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
  131. AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list);
  132. AnalysisContextPtr graph_context() const { return graph_context_; }
  133. protected:
  134. AnalysisContextPtr parent_context_;
  135. private:
  136. AnalysisContextPtr graph_context_;
  137. };
  138. class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
  139. public:
  140. FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
  141. : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {}
  142. ~FuncGraphEvaluator() override = default;
  143. MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
  144. FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  145. FuncGraphPtr func_graph() { return func_graph_; }
  146. AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override;
  147. std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }
  148. private:
  149. FuncGraphPtr func_graph_;
  150. std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
  151. func_graph_cache_;
  152. };
  153. using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
  154. class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
  155. public:
  156. // Note: context parameter is not used;
  157. MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, AnalysisContextPtr, const ScopePtr &scope)
  158. : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {}
  159. ~MetaFuncGraphEvaluator() override = default;
  160. MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator);
  161. FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  162. // Return normalized versions of the arguments.
  163. AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
  164. return meta_func_graph_->NormalizeArgs(args_spec_list);
  165. }
  166. std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); }
  167. private:
  168. MetaFuncGraphPtr meta_func_graph_;
  169. std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
  170. func_graph_cache_;
  171. ScopePtr scope_;
  172. };
  173. class PartialAppEvaluator : public Evaluator {
  174. public:
  175. PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args)
  176. : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {}
  177. ~PartialAppEvaluator() override = default;
  178. MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator);
  179. AnfNodePtr bound_node() const override {
  180. if (evaluator_ != nullptr) {
  181. return evaluator_->bound_node();
  182. }
  183. return bound_node_.lock();
  184. }
  185. void set_bound_node(const AnfNodePtr &node) override {
  186. if (evaluator_ != nullptr) {
  187. evaluator_->set_bound_node(node);
  188. }
  189. bound_node_ = AnfNodeWeakPtr(node);
  190. }
  191. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  192. MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
  193. }
  194. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  195. AnfNodeConfigPtr out_conf) override;
  196. std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
  197. private:
  198. EvaluatorPtr evaluator_;
  199. AbstractBasePtrList args_spec_list_;
  200. };
  201. class VirtualEvaluator : public Evaluator {
  202. public:
  203. VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output)
  204. : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {}
  205. ~VirtualEvaluator() override = default;
  206. MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
  207. AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  208. std::string ToString() const override { return identifier_; }
  209. private:
  210. AbstractBasePtrList args_spec_list_;
  211. AbstractBasePtr output_;
  212. };
  213. class JEvaluator : public Evaluator {
  214. public:
  215. JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
  216. : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
  217. ~JEvaluator() override = default;
  218. MS_DECLARE_PARENT(JEvaluator, Evaluator);
  219. AnfNodePtr bound_node() const override {
  220. if (evaluator_ != nullptr) {
  221. return evaluator_->bound_node();
  222. }
  223. return bound_node_.lock();
  224. }
  225. void set_bound_node(const AnfNodePtr &node) override {
  226. if (evaluator_ != nullptr) {
  227. evaluator_->set_bound_node(node);
  228. }
  229. bound_node_ = AnfNodeWeakPtr(node);
  230. }
  231. AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  232. MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
  233. }
  234. AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
  235. AnfNodeConfigPtr out_conf) override;
  236. std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
  237. private:
  238. EvaluatorPtr evaluator_;
  239. AbstractFunctionPtr orig_func_;
  240. };
  241. } // namespace abstract
  242. } // namespace mindspore
  243. #endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_