You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

abstract_function.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
  19. #define PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>

#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/analysis_context.h"
#include "ir/meta_func_graph.h"
  25. namespace mindspore {
  26. namespace abstract {
  27. class AbstractFuncAtom : public AbstractFunction {
  28. public:
  29. AbstractFuncAtom() = default;
  30. ~AbstractFuncAtom() override = default;
  31. MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
  32. AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
  33. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
  34. MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
  35. }
  36. AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
  37. void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
  38. bool operator==(const AbstractFunction &other) const override;
  39. std::size_t hash() const override { return tid(); }
  40. };
  41. class AbstractFuncUnion : public AbstractFunction {
  42. public:
  43. explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list);
  44. AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second);
  45. ~AbstractFuncUnion() override = default;
  46. MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction)
  47. std::string ToString() const override;
  48. AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
  49. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
  50. MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
  51. }
  52. bool IsSuperSet(const AbstractFunctionPtr &other);
  53. AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
  54. void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
  55. bool operator==(const AbstractFunction &other) const override;
  56. std::size_t hash() const override;
  57. AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }
  58. private:
  59. AbstractFuncAtomPtrList func_list_;
  60. };
  61. class PrimitiveAbstractClosure : public AbstractFuncAtom {
  62. public:
  63. // Represents a Primitive.
  64. // prim: The primitive
  65. // tracking_id: Identifies different uses of the same primitive.
  66. explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr)
  67. : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {}
  68. ~PrimitiveAbstractClosure() override = default;
  69. MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
  70. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  71. PrimitivePtr prim() { return prim_; }
  72. AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
  73. void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }
  74. AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id()); }
  75. bool operator==(const AbstractFunction &other) const override;
  76. std::size_t hash() const override;
  77. std::string ToString() const override { return "Prim: " + prim_->name(); }
  78. private:
  79. PrimitivePtr prim_;
  80. // store it as weak_ptr to break reference cycle.
  81. // one reference cycle example is Graph::set_output() input0 local variable.
  82. AnfNodeWeakPtr tracking_id_;
  83. };
  84. class FuncGraphAbstractClosure : public AbstractFuncAtom {
  85. public:
  86. // Represents a Graph in a certain Context.
  87. // context: The context, or Context.empty()
  88. FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
  89. : func_graph_(func_graph), context_(context) {
  90. MS_EXCEPTION_IF_NULL(func_graph);
  91. MS_EXCEPTION_IF_NULL(context);
  92. }
  93. ~FuncGraphAbstractClosure() override = default;
  94. MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
  95. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  96. FuncGraphPtr func_graph() { return func_graph_; }
  97. AnalysisContextPtr context() const override { return context_; }
  98. AbstractFunctionPtr Copy() const override {
  99. return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_);
  100. }
  101. bool operator==(const AbstractFunction &other) const override;
  102. std::size_t hash() const override;
  103. std::string ToString() const override;
  104. private:
  105. FuncGraphPtr func_graph_;
  106. AnalysisContextPtr context_;
  107. };
  108. using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
  109. class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
  110. public:
  111. explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope)
  112. : meta_func_graph_(meta_func_graph), scope_(scope) {}
  113. ~MetaFuncGraphAbstractClosure() override = default;
  114. MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
  115. MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; }
  116. AnalysisContextPtr context() const override { return kDummyAnalysisContext; }
  117. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  118. ScopePtr GetScope() { return scope_; }
  119. AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
  120. bool operator==(const AbstractFunction &other) const override;
  121. std::size_t hash() const override;
  122. std::string ToString() const override;
  123. private:
  124. MetaFuncGraphPtr meta_func_graph_;
  125. ScopePtr scope_;
  126. };
  127. using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
  128. class PartialAbstractClosure : public AbstractFuncAtom {
  129. public:
  130. // Represents a partial application.
  131. // args_spec_list: The first few arguments of that function
  132. PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list)
  133. : fn_(fn), args_spec_list_(args_spec_list) {}
  134. ~PartialAbstractClosure() override = default;
  135. MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
  136. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  137. AbstractFunctionPtr fn() { return fn_; }
  138. AbstractBasePtrList args() { return args_spec_list_; }
  139. AbstractFunctionPtr Copy() const override { return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_); }
  140. bool operator==(const AbstractFunction &other) const override;
  141. std::size_t hash() const override;
  142. std::string ToString() const override;
  143. private:
  144. AbstractFuncAtomPtr fn_;
  145. AbstractBasePtrList args_spec_list_;
  146. };
  147. class JTransformedAbstractClosure : public AbstractFuncAtom {
  148. public:
  149. // Represents a Function transformed through the application of J.
  150. explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
  151. ~JTransformedAbstractClosure() override = default;
  152. MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
  153. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  154. AbstractFuncAtomPtr fn() { return fn_; }
  155. AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
  156. bool operator==(const AbstractFunction &other) const override;
  157. std::size_t hash() const override;
  158. std::string ToString() const override { return "J(" + fn_->ToString() + ")"; }
  159. private:
  160. AbstractFuncAtomPtr fn_;
  161. };
  162. class VirtualAbstractClosure : public AbstractFuncAtom {
  163. public:
  164. // Represents some function with an explicitly fixed type signature.
  165. // args_spec_list: The arguments as abstract value given to the function
  166. // output: The output which is abstract value.
  167. VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec)
  168. : args_spec_list_(args_spec_list), output_(output_spec) {}
  169. VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec)
  170. : args_spec_list_({args_spec}), output_(output_spec) {}
  171. ~VirtualAbstractClosure() override = default;
  172. MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
  173. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  174. AbstractBasePtrList args_spec_list() { return args_spec_list_; }
  175. AbstractBasePtr output() { return output_; }
  176. AbstractFunctionPtr Copy() const override {
  177. return std::make_shared<VirtualAbstractClosure>(args_spec_list_, output_);
  178. }
  179. bool operator==(const AbstractFunction &other) const override;
  180. std::size_t hash() const override;
  181. std::string ToString() const override;
  182. private:
  183. AbstractBasePtrList args_spec_list_;
  184. AbstractBasePtr output_;
  185. };
  186. using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>;
  187. class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
  188. public:
  189. // Represents a Primitive with an explicitly fixed type signature.
  190. // args_spec_list: The arguments as abstract value given to the Primitive
  191. // output: The output which is abstract value.
  192. TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list,
  193. const AbstractBasePtr &output_spec)
  194. : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {}
  195. ~TypedPrimitiveAbstractClosure() override = default;
  196. MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
  197. EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
  198. PrimitivePtr prim() { return prim_; }
  199. AbstractBasePtrList args_spec_list() { return args_spec_list_; }
  200. AbstractBasePtr output() { return output_; }
  201. AbstractFunctionPtr Copy() const override {
  202. return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_spec_list_, output_);
  203. }
  204. bool operator==(const AbstractFunction &other) const override;
  205. std::size_t hash() const override;
  206. std::string ToString() const override;
  207. private:
  208. PrimitivePtr prim_;
  209. AbstractBasePtrList args_spec_list_;
  210. AbstractBasePtr output_;
  211. };
  212. // Represents a function that can't be called.
  213. class DummyAbstractClosure : public AbstractFuncAtom {
  214. public:
  215. DummyAbstractClosure() = default;
  216. ~DummyAbstractClosure() override = default;
  217. MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
  218. EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
  219. AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
  220. bool operator==(const AbstractFunction &other) const override;
  221. std::string ToString() const override { return "DummyAbstractClosure()"; }
  222. };
  223. struct AbstractFunctionHasher {
  224. std::size_t operator()(const AbstractFunctionPtr &t) const {
  225. std::size_t hash = t->hash();
  226. return hash;
  227. }
  228. };
  229. struct AbstractFunctionEqual {
  230. bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; }
  231. };
  232. } // namespace abstract
  233. } // namespace mindspore
  234. #endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_