You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.h 13 kB

(line-number gutter: lines 1–323)
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  19. #define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include <list>
  24. #include <memory>
  25. #include <unordered_map>
  26. #include <unordered_set>
  27. #include "ir/anf.h"
  28. #include "ir/manager.h"
  29. #include "utils/any.h"
  30. #include "utils/ordered_set.h"
  31. #include "pipeline/static_analysis/abstract_value.h"
  32. namespace mindspore {
  33. using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
  34. using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
  35. using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
  // Keys used in FuncGraph::flags() / has_flag(). Each is a compile-time
  // NUL-terminated character array with internal linkage (one copy per TU).
  constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
  constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
  constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
  39. // ANF transform class
  40. // either a primitive or a func_graph
  41. class FuncGraphTransform {
  42. public:
  43. enum Type { kGtPrimitive, kGtFuncGraph };
  44. explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
  45. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  46. explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
  47. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  48. FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
  49. ~FuncGraphTransform() = default;
  50. Type type() const {
  51. if (IsFuncGraph()) {
  52. return kGtFuncGraph;
  53. } else {
  54. return kGtPrimitive;
  55. }
  56. }
  57. bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
  58. bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
  59. FuncGraphPtr func_graph() const { return func_graph_.lock(); }
  60. PrimitivePtr primitive() const { return prim_; }
  61. FuncGraphTransform &operator=(const FuncGraphTransform &t) {
  62. if (this != &t) {
  63. prim_ = t.prim_;
  64. func_graph_ = t.func_graph_;
  65. }
  66. return *this;
  67. }
  68. private:
  69. PrimitivePtr prim_;
  70. // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
  71. // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
  72. // FPropRemapper::FinalizeGraph().
  73. FuncGraphWeakPtr func_graph_;
  74. static const PrimitivePtr func_graph_prim_;
  75. };
  76. class FuncGraphBase : public Value {
  77. public:
  78. FuncGraphBase() = default;
  79. ~FuncGraphBase() override = default;
  80. MS_DECLARE_PARENT(FuncGraphBase, Value);
  81. };
  82. extern const char kFuncGraphFlagUndetermined[];
  83. class FuncGraph : public FuncGraphBase {
  84. public:
  85. FuncGraph();
  86. ~FuncGraph() override = default;
  87. MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
  88. // get the graph's abstract
  89. abstract::AbstractFunctionPtr abstract();
  90. abstract::AbstractBasePtr MakeAbstractClosure(const abstract::AnalysisContextPtr &context);
  91. // return the graph's output, or nullptr if not yet deduced
  92. AnfNodePtr output() const;
  93. void set_output(const AnfNodePtr &value, bool force_new_ret = false);
  94. const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
  95. virtual ParameterPtr add_parameter();
  96. void add_parameter(const ParameterPtr &p);
  97. void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
  98. // add a weight parameter with specific name
  99. ParameterPtr AddWeightParameter(const std::string &name);
  100. // create a cnode with given inputs, bound to this graph
  101. virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
  102. // create a cnode with given inputs, bound to this graph, and set to specific scope
  103. CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
  104. // Functions for handling variable argument, keyword-only arguments and variable keyword argument
  105. AnfNodePtr GetDefaultValueByName(const std::string &name);
  106. void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
  107. parameter_default_value_[name] = node;
  108. }
  109. void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
  110. void ClearDefaultValues();
  111. size_t GetDefaultValueCount();
  112. std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
  113. void set_has_vararg(bool has_) { has_vararg_ = has_; }
  114. bool has_vararg() const { return has_vararg_; }
  115. AnfNodePtr GetVariableArgParameter();
  116. std::string GetVariableArgName();
  117. void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
  118. bool has_kwarg() const { return has_kwarg_; }
  119. void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
  120. int kwonlyargs_count() const { return kwonlyargs_count_; }
  121. AnfNodePtr GetVariableKwargParameter();
  122. std::string GetVariableKwargName();
  123. void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
  124. size_t hyper_param_count() const { return hyper_param_count_; }
  125. int GetPositionalArgsCount() const;
  126. AnfNodePtr GetParameterByName(const std::string &name);
  127. bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
  128. FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
  129. void set_is_generate(bool generated) { is_generated_ = generated; }
  130. bool is_generated() const { return is_generated_; }
  131. bool has_flag(const std::string &flag);
  132. std::unordered_map<std::string, bool> &flags() { return flags_; }
  133. void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; }
  134. void set_flags(const std::string &key, const bool value) { flags_[key] = value; }
  135. std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
  136. void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
  137. transforms_ = transforms;
  138. }
  139. CNodePtr get_return() const { return return_; }
  140. void set_return(const CNodePtr &cnode) { return_ = cnode; }
  141. FuncGraphManagerPtr manager() const { return manager_.lock(); }
  142. void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
  143. std::string ToString() const override;
  144. GraphDebugInfoPtr debug_info();
  145. void set_debug_info(const GraphDebugInfoPtr &info) {
  146. if (info == nullptr) {
  147. MS_LOG(EXCEPTION) << "Graph set null debug info";
  148. }
  149. this->debug_info_ = info;
  150. }
  151. // get all nodes belonging to this func graph
  152. const AnfNodeSet &nodes();
  153. // get all value_nodes belonging to this func graph
  154. const AnfNodeCounterMap &value_nodes();
  155. // get all vars directly pointed to in this func graph
  156. const AnfNodeCounterMap &free_variables_direct();
  157. // get all vars required by this func graph
  158. const BaseRefCounterMap &free_variables_total();
  159. // Return the set of graphs free_variables_total belong to.
  160. std::vector<AnfNodePtr> free_variables_nodes();
  161. // get all vars that are func graphs
  162. std::vector<FuncGraphPtr> free_variables_func_graphs();
  163. // get all func graphs directly used by this func graph
  164. const FuncGraphCounterMap &func_graphs_used();
  165. // get all func graphs nested used by this func graph
  166. const FuncGraphSet &func_graphs_used_total();
  167. // get all users of this func graph
  168. const FuncGraphCounterMap &func_graph_users();
  169. // get all user cnodes of this func graph
  170. const AnfNodeCounterMap &func_graph_user_cnodes();
  171. // Return the parent of this graph.
  172. FuncGraphPtr parent();
  173. // Return the children of this graph.
  174. const FuncGraphSet &children();
  175. // Return the scope of this graph, scope have graph self but children not have.
  176. const FuncGraphSet &scope();
  177. // Return whether this graph is recursive
  178. bool recursive();
  179. // Return graphs which forms a recursive loop
  180. std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
  181. std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
  182. void DumpFuncGraph(const std::string &path = "./func_graph.dot");
  183. bool operator==(const Value &other) const override {
  184. if (other.isa<FuncGraph>()) {
  185. return &other == this;
  186. } else {
  187. return false;
  188. }
  189. }
  190. void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  191. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  192. int pos_args_input_count);
  193. void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  194. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  195. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  196. void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  197. const std::vector<AnfNodePtr> &specialized_parameter_list,
  198. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  199. const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
  200. void add_parameter_obj_node(const AnfNodePtr &p);
  201. std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
  202. std::unordered_map<std::string, bool> flags_;
  203. std::unordered_map<std::string, FuncGraphTransform> transforms_;
  204. // parameter default value
  205. std::map<std::string, AnfNodePtr> parameter_default_value_;
  206. std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
  207. std::list<CNodePtr> GetOrderedCnodes();
  208. void EraseUnusedNodeInOrder(const AnfNodePtr &n);
  209. void EraseUnusedNodeInOrder();
  210. void CheckOrder();
  211. void DumpCNodeList();
  212. void ReleaseFullOrderToEffectOrder();
  213. void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
  214. bool HasEffect(const CNodePtr &cnode);
  215. private:
  216. // graph is manipulated by manager and others
  217. friend FuncGraphManager;
  218. // parameters of this function
  219. std::vector<AnfNodePtr> parameters_;
  220. std::vector<AnfNodePtr> paramter_obj_nodes_;
  221. // whether there is a *args and **kwargs, and count kwonlyargs'number
  222. bool has_vararg_;
  223. bool has_kwarg_;
  224. int kwonlyargs_count_;
  225. // the hyper param is placed on the top graph,
  226. // and positioned in the end of the param list, so we record the number to trace the position
  227. size_t hyper_param_count_;
  228. // the argument input list for the graph used to generate this graph
  229. bool is_generated_;
  230. // the cnode that calls 'return' primitive
  231. // we use shared pointer to manage it.
  232. CNodePtr return_;
  233. // back-ref to its manager
  234. // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
  235. // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
  236. // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
  237. // In some ut test cases, they may use local FuncGraphManager in function which
  238. // generating the func graph, when go outside of that function, func graph will have no
  239. // FuncGraphManager. In that special case, Manage() should be called to make the func graph
  240. // managed.
  241. std::weak_ptr<FuncGraphManager> manager_;
  242. GraphDebugInfoPtr debug_info_;
  243. void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  244. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  245. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  246. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
  247. // CNode order which relates to origin code order
  248. std::list<CNodePtr> order_;
  249. };
  250. inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
  251. MS_EXCEPTION_IF_NULL(fg);
  252. return fg->NewCNode(inputs);
  253. }
  254. // Find the root cnodes of a segment of cnodes.
  255. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
  256. // Find the leaf cnodes of a segment of cnodes.
  257. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
  258. } // namespace mindspore
  259. #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_