You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

static_analysis.h 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
  19. #define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
  #include <list>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #ifdef DEBUG
  #include <stack>
  #endif
  #include "utils/log_adapter.h"
  #include "ir/anf.h"
  #include "ir/primitive.h"
  #include "pipeline/static_analysis/analysis_context.h"
  #include "pipeline/static_analysis/abstract_function.h"
  #include "pipeline/parse/parse.h"
  35. namespace mindspore {
  36. namespace abstract {
  37. // Superclass for AnfNodeConfig and VirtualConfig.
  38. class Config : public Base {
  39. public:
  40. Config() = default;
  41. ~Config() override = default;
  42. MS_DECLARE_PARENT(Config, Base);
  43. virtual AbstractBasePtr GetEvaluatedValue() = 0;
  44. };
  45. // Config will be stored in AnalysisCache
  46. using ConfigPtr = std::shared_ptr<Config>;
  47. using ConfigPtrList = std::vector<ConfigPtr>;
  48. // Config to a certain node in a certain context.
  49. class AnfNodeConfig : public Config {
  50. public:
  51. AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
  52. : Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) {
  53. FuncGraphPtr fg;
  54. if (IsValueNode<FuncGraph>(node)) {
  55. auto v = node->cast<ValueNodePtr>();
  56. fg = v->value()->cast<FuncGraphPtr>();
  57. } else {
  58. fg = node->func_graph();
  59. }
  60. context_ = nullptr;
  61. if (context != nullptr) {
  62. context_ = context->Filter(fg);
  63. }
  64. }
  65. ~AnfNodeConfig() override = default;
  66. MS_DECLARE_PARENT(AnfNodeConfig, Config);
  67. AbstractBasePtr GetEvaluatedValue() override;
  68. AnalysisContextPtr context() const { return context_; }
  69. AnfNodePtr node() const { return node_; }
  70. AnalysisEnginePtr engine() const { return engine_.lock(); }
  71. // used by unordered_map;
  72. bool operator==(const AnfNodeConfig &other) const {
  73. // compare node with pointer, context with content;
  74. // context should not be nullptr;
  75. return (node_ == other.node_) && (*context_ == *other.context_);
  76. }
  77. std::string ToString() const override {
  78. std::ostringstream buffer;
  79. buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
  80. return buffer.str();
  81. }
  82. private:
  83. // AnalysisEngine is global.
  84. // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use
  85. // weak_ptr to break Config cycle.
  86. std::weak_ptr<AnalysisEngine> engine_;
  87. AnfNodePtr node_;
  88. AnalysisContextPtr context_;
  89. };
  90. using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
  91. struct AnfNodeConfigHasher {
  92. std::size_t operator()(const AnfNodeConfigPtr conf) const;
  93. };
  94. struct AnfNodeConfigEqual {
  95. bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const;
  96. };
  97. class VirtualConfig : public Config {
  98. public:
  99. explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {}
  100. ~VirtualConfig() override = default;
  101. MS_DECLARE_PARENT(VirtualConfig, Config);
  102. AbstractBasePtr GetEvaluatedValue() override { return abstract_; }
  103. private:
  104. AbstractBasePtr abstract_;
  105. };
  106. // AnalysisCache
  107. class AnalysisCache {
  108. public:
  109. AnalysisCache() = default;
  110. ~AnalysisCache() = default;
  111. void Clear() { cache_.clear(); }
  112. void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  113. AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf);
  114. private:
  115. std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
  116. };
  117. using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
  118. using AnfNodeConfigMap =
  119. std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  120. struct AnalysisResult {
  121. AbstractBasePtr inferred;
  122. AnalysisContextPtr context;
  123. };
  124. class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  125. public:
  126. AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
  127. : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {}
  128. ~AnalysisEngine() = default;
  129. // func_graph: The func_graph to analyze.
  130. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
  131. AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
  132. AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
  133. // Return the Evaluator for the given function.
  134. EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
  135. AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
  136. AbstractBasePtr InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
  137. // Infer the result of fn(args).
  138. AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
  139. void Clear();
  140. void ClearEvaluatorCache();
  141. AnalysisCache &cache() { return cache_; }
  142. AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
  143. return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
  144. }
  145. // Overloaded function.
  146. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);
  147. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &fn);
  148. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &fn);
  149. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &fn);
  150. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
  151. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
  152. EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
  153. FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
  154. const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }
  155. // Set the analysis result for orig to the result for new.
  156. // This sets an entry in anfnode_config_map from orig to new.
  157. AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
  158. // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
  159. (void)anfnode_config_map_.emplace(orig_conf, new_conf);
  160. MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
  161. << ", to new_conf: " << new_conf->node()->DebugString();
  162. return GetEvaluatedValue(new_conf);
  163. }
  164. const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
  165. AnalysisCache cache_;
  166. std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
  167. private:
  168. const PrimEvaluatorMap &prim_constructors_;
  169. FuncGraphManagerPtr func_graph_manager_;
  170. std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
  171. AnfNodeConfigMap anfnode_config_map_;
  172. // Use a list to trace multiple evaluators.
  173. std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
  174. AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
  175. const ConfigPtrList &args_conf_list);
  176. AbstractBasePtr Eval(const AnfNodeConfigPtr &conf);
  177. EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
  178. AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
  179. const ConfigPtrList &args_conf_list);
  180. AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
  181. const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list);
  182. #ifdef DEBUG
  183. std::vector<AnfNodePtr> compute_conf_stack_;
  184. #endif
  185. };
  186. // Translate the value to an abstract value.
  187. // Arguments:
  188. // value: The value to convert.
  189. // context: The context in which the value was found, used if the value is a Graph.
  190. // conf: The Config to the valuenode we are converting, if there is one,
  191. // so that we can generate a tracking_id.
  192. AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr,
  193. const AnfNodeConfigPtr &conf = nullptr);
  194. // Convert a value to an abstract value.
  195. // Arguments:
  196. // v: The value to convert.
  197. // broaden: If True, concrete values will be made more abstract, so e.g.
  198. // the value 1234 would become ANYTHING.
  199. AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false);
  200. template <typename T>
  201. AbstractBasePtr FromValue(const T &value, bool broaden = false) {
  202. return FromValueInside(MakeValue(value), broaden);
  203. }
  204. AbstractBasePtr InferOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
  205. } // namespace abstract
  206. } // namespace mindspore
  207. #endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_