You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abstract_function.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
  19. #define PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
  20. #include <memory>
  21. #include <string>
  22. #include "pipeline/static_analysis/abstract_value.h"
  23. #include "pipeline/static_analysis/analysis_context.h"
  24. #include "ir/meta_func_graph.h"
  25. namespace mindspore {
  26. namespace abstract {
  27. class AbstractFuncAtom : public AbstractFunction {
  28. public:
  29. AbstractFuncAtom() = default;
  30. ~AbstractFuncAtom() override = default;
  31. MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
  32. AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
  33. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
  34. MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
  35. }
  36. AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
  37. void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
  38. bool operator==(const AbstractFunction &other) const override;
  39. std::size_t hash() const override { return tid(); }
  40. };
  41. class AbstractFuncUnion : public AbstractFunction {
  42. public:
  43. explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list);
  44. AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second);
  45. ~AbstractFuncUnion() override = default;
  46. MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction)
  47. std::string ToString() const override;
  48. AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
  49. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
  50. MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
  51. }
  52. bool IsSuperSet(const AbstractFunctionPtr &other);
  53. AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
  54. void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
  55. bool operator==(const AbstractFunction &other) const override;
  56. std::size_t hash() const override;
  57. AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }
  58. private:
  59. AbstractFuncAtomPtrList func_list_;
  60. };
  61. class PrimitiveAbstractClosure : public AbstractFuncAtom {
  62. public:
  63. // Represents a Primitive.
  64. // prim: The primitive
  65. // tracking_id: Identifies different uses of the same primitive.
  66. explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr)
  67. : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {}
  68. ~PrimitiveAbstractClosure() override = default;
  69. MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
  70. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  71. PrimitivePtr prim() { return prim_; }
  72. AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
  73. void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
  74. AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id()); }
  75. bool operator==(const AbstractFunction &other) const override;
  76. std::size_t hash() const override;
  77. std::string ToString() const override { return "Prim: " + prim_->name(); }
  78. private:
  79. PrimitivePtr prim_;
  80. // store it as weak_ptr to break reference cycle.
  81. // one reference cycle example is Graph::set_output() input0 local variable.
  82. AnfNodeWeakPtr tracking_id_;
  83. };
  84. class FuncGraphAbstractClosure : public AbstractFuncAtom {
  85. public:
  86. // Represents a Graph in a certain Context.
  87. // context: The context, or Context.empty()
  88. FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
  89. : func_graph_(func_graph), context_(context) {
  90. MS_EXCEPTION_IF_NULL(func_graph);
  91. MS_EXCEPTION_IF_NULL(context);
  92. }
  93. ~FuncGraphAbstractClosure() override = default;
  94. MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
  95. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  96. FuncGraphPtr func_graph() { return func_graph_; }
  97. AnalysisContextPtr context() const override { return context_; }
  98. AbstractFunctionPtr Copy() const override {
  99. return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_);
  100. }
  101. bool operator==(const AbstractFunction &other) const override;
  102. std::size_t hash() const override;
  103. std::string ToString() const override;
  104. private:
  105. FuncGraphPtr func_graph_;
  106. AnalysisContextPtr context_;
  107. };
  108. using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
  109. class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
  110. public:
  111. explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope)
  112. : meta_func_graph_(meta_func_graph), scope_(scope) {}
  113. ~MetaFuncGraphAbstractClosure() override = default;
  114. MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
  115. MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; }
  116. AnalysisContextPtr context() const override { return kDummyAnalysisContext; }
  117. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  118. ScopePtr GetScope() { return scope_; }
  119. AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
  120. bool operator==(const AbstractFunction &other) const override;
  121. std::size_t hash() const override;
  122. std::string ToString() const override;
  123. private:
  124. MetaFuncGraphPtr meta_func_graph_;
  125. ScopePtr scope_;
  126. };
  127. using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
  128. class PartialAbstractClosure : public AbstractFuncAtom {
  129. public:
  130. // Represents a partial application.
  131. // args_spec_list: The first few arguments of that function
  132. PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
  133. const AnfNodePtr &node = nullptr)
  134. : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
  135. ~PartialAbstractClosure() override = default;
  136. MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
  137. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  138. AbstractFunctionPtr fn() { return fn_; }
  139. AbstractBasePtrList args() { return args_spec_list_; }
  140. AnfNodePtr node() { return node_.lock(); }
  141. void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
  142. AbstractFunctionPtr Copy() const override {
  143. return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock());
  144. }
  145. bool operator==(const AbstractFunction &other) const override;
  146. std::size_t hash() const override;
  147. std::string ToString() const override;
  148. private:
  149. AbstractFuncAtomPtr fn_;
  150. AbstractBasePtrList args_spec_list_;
  151. // The CNode which this PartialAbstractClosure evaluated from.
  152. AnfNodeWeakPtr node_;
  153. };
  154. class JTransformedAbstractClosure : public AbstractFuncAtom {
  155. public:
  156. // Represents a Function transformed through the application of J.
  157. explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
  158. ~JTransformedAbstractClosure() override = default;
  159. MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
  160. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  161. AbstractFuncAtomPtr fn() { return fn_; }
  162. AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
  163. bool operator==(const AbstractFunction &other) const override;
  164. std::size_t hash() const override;
  165. std::string ToString() const override { return "J(" + fn_->ToString() + ")"; }
  166. private:
  167. AbstractFuncAtomPtr fn_;
  168. };
  169. class VirtualAbstractClosure : public AbstractFuncAtom {
  170. public:
  171. // Represents some function with an explicitly fixed type signature.
  172. // args_spec_list: The arguments as abstract value given to the function
  173. // output: The output which is abstract value.
  174. VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec)
  175. : args_spec_list_(args_spec_list), output_(output_spec) {}
  176. VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec)
  177. : args_spec_list_({args_spec}), output_(output_spec) {}
  178. ~VirtualAbstractClosure() override = default;
  179. MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
  180. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  181. AbstractBasePtrList args_spec_list() { return args_spec_list_; }
  182. AbstractBasePtr output() { return output_; }
  183. AbstractFunctionPtr Copy() const override {
  184. return std::make_shared<VirtualAbstractClosure>(args_spec_list_, output_);
  185. }
  186. bool operator==(const AbstractFunction &other) const override;
  187. std::size_t hash() const override;
  188. std::string ToString() const override;
  189. private:
  190. AbstractBasePtrList args_spec_list_;
  191. AbstractBasePtr output_;
  192. };
  193. using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>;
  194. class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
  195. public:
  196. // Represents a Primitive with an explicitly fixed type signature.
  197. // args_spec_list: The arguments as abstract value given to the Primitive
  198. // output: The output which is abstract value.
  199. TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list,
  200. const AbstractBasePtr &output_spec)
  201. : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {}
  202. ~TypedPrimitiveAbstractClosure() override = default;
  203. MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
  204. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  205. PrimitivePtr prim() { return prim_; }
  206. AbstractBasePtrList args_spec_list() { return args_spec_list_; }
  207. AbstractBasePtr output() { return output_; }
  208. AbstractFunctionPtr Copy() const override {
  209. return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_spec_list_, output_);
  210. }
  211. bool operator==(const AbstractFunction &other) const override;
  212. std::size_t hash() const override;
  213. std::string ToString() const override;
  214. private:
  215. PrimitivePtr prim_;
  216. AbstractBasePtrList args_spec_list_;
  217. AbstractBasePtr output_;
  218. };
  219. // Represents a function that can't be called.
  220. class DummyAbstractClosure : public AbstractFuncAtom {
  221. public:
  222. DummyAbstractClosure() = default;
  223. ~DummyAbstractClosure() override = default;
  224. MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
  225. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
  226. AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
  227. bool operator==(const AbstractFunction &other) const override;
  228. std::string ToString() const override { return "DummyAbstractClosure()"; }
  229. };
  230. struct AbstractFunctionHasher {
  231. std::size_t operator()(const AbstractFunctionPtr &t) const {
  232. std::size_t hash = t->hash();
  233. return hash;
  234. }
  235. };
  236. struct AbstractFunctionEqual {
  237. bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; }
  238. };
  239. } // namespace abstract
  240. } // namespace mindspore
  241. #endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_