You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.h 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  19. #define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include <list>
  24. #include <memory>
  25. #include <unordered_map>
  26. #include <unordered_set>
  27. #include <functional>
  28. #include "ir/anf.h"
  29. #include "ir/manager.h"
  30. #include "utils/any.h"
  31. #include "utils/ordered_set.h"
  32. #include "pipeline/static_analysis/abstract_value.h"
  33. namespace mindspore {
  34. using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
  35. using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
  36. template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
  37. using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
  38. using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
  39. using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
  40. using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
  // Names of well-known func-graph flags, presumably used as keys with FuncGraph::has_flag() /
  // FuncGraph::set_flags(). Namespace-scope constexpr arrays have internal linkage, so every
  // translation unit including this header gets its own copy (same as a plain const array).
  constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
  constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
  constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
  constexpr char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
  45. // ANF transform class
  46. // either a primitive or a func_graph
  47. class FuncGraphTransform {
  48. public:
  49. enum Type { kGtPrimitive, kGtFuncGraph };
  50. explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
  51. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  52. explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
  53. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  54. FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
  55. ~FuncGraphTransform() = default;
  56. Type type() const {
  57. if (IsFuncGraph()) {
  58. return kGtFuncGraph;
  59. } else {
  60. return kGtPrimitive;
  61. }
  62. }
  63. bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
  64. bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
  65. FuncGraphPtr func_graph() const { return func_graph_.lock(); }
  66. PrimitivePtr primitive() const { return prim_; }
  67. FuncGraphTransform &operator=(const FuncGraphTransform &t) {
  68. if (this != &t) {
  69. prim_ = t.prim_;
  70. func_graph_ = t.func_graph_;
  71. }
  72. return *this;
  73. }
  74. private:
  75. PrimitivePtr prim_;
  76. // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
  77. // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
  78. // FPropRemapper::FinalizeGraph().
  79. FuncGraphWeakPtr func_graph_;
  80. static const PrimitivePtr func_graph_prim_;
  81. };
  82. class FuncGraphBase : public Value {
  83. public:
  84. FuncGraphBase() = default;
  85. ~FuncGraphBase() override = default;
  86. MS_DECLARE_PARENT(FuncGraphBase, Value);
  87. };
  88. extern const char kFuncGraphFlagUndetermined[];
  89. class FuncGraph : public FuncGraphBase {
  90. public:
  91. FuncGraph();
  92. ~FuncGraph() override = default;
  93. MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
  94. // get the graph's abstract
  95. abstract::AbstractFunctionPtr abstract();
  96. abstract::AbstractBasePtr MakeAbstractClosure(const abstract::AnalysisContextPtr &context);
  97. // return the graph's output, or nullptr if not yet deduced
  98. AnfNodePtr output() const;
  99. void set_output(const AnfNodePtr &value, bool force_new_ret = false);
  100. const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
  101. virtual ParameterPtr add_parameter();
  102. void add_parameter(const ParameterPtr &p);
  103. void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
  104. // add a weight parameter with specific name
  105. ParameterPtr AddWeightParameter(const std::string &name);
  106. // create a cnode with given inputs, bound to this graph
  107. virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
  108. // create a cnode with given inputs, bound to this graph, and set to specific scope
  109. CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
  110. // Functions for handling variable argument, keyword-only arguments and variable keyword argument
  111. AnfNodePtr GetDefaultValueByName(const std::string &name);
  112. void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
  113. parameter_default_value_[name] = node;
  114. }
  115. void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
  116. void ClearDefaultValues();
  117. size_t GetDefaultValueCount();
  118. std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
  119. void set_has_vararg(bool has_) { has_vararg_ = has_; }
  120. bool has_vararg() const { return has_vararg_; }
  121. AnfNodePtr GetVariableArgParameter();
  122. std::string GetVariableArgName();
  123. void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
  124. bool has_kwarg() const { return has_kwarg_; }
  125. void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
  126. int kwonlyargs_count() const { return kwonlyargs_count_; }
  127. AnfNodePtr GetVariableKwargParameter();
  128. std::string GetVariableKwargName();
  129. void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
  130. size_t hyper_param_count() const { return hyper_param_count_; }
  131. int GetPositionalArgsCount() const;
  132. AnfNodePtr GetParameterByName(const std::string &name);
  133. bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
  134. FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
  135. void set_is_generate(bool generated) { is_generated_ = generated; }
  136. bool is_generated() const { return is_generated_; }
  137. bool has_flag(const std::string &flag);
  138. std::unordered_map<std::string, bool> &flags() { return flags_; }
  139. void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; }
  140. void set_flags(const std::string &key, const bool value) { flags_[key] = value; }
  141. std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
  142. void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
  143. transforms_ = transforms;
  144. }
  145. CNodePtr get_return() const { return return_; }
  146. void set_return(const CNodePtr &cnode) { return_ = cnode; }
  147. FuncGraphManagerPtr manager() const { return manager_.lock(); }
  148. void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
  149. std::string ToString() const override;
  150. GraphDebugInfoPtr debug_info();
  151. void set_debug_info(const GraphDebugInfoPtr &info) {
  152. if (info == nullptr) {
  153. MS_LOG(EXCEPTION) << "Graph set null debug info";
  154. }
  155. this->debug_info_ = info;
  156. }
  157. // get all nodes belonging to this func graph
  158. const AnfNodeSet &nodes();
  159. void CopyNodes(const AnfNodeSet &other_nodes);
  160. void ClearNodes();
  161. void AddNode(AnfNodePtr node);
  162. void DropNode(AnfNodePtr node);
  163. // get all value_nodes belonging to this func graph
  164. const AnfNodeCounterMap &value_nodes();
  165. void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes);
  166. void ClearValueNodes();
  167. void AddValueNode(AnfNodePtr node, int count = 1);
  168. void DropValueNode(AnfNodePtr node);
  169. // get all free vars directly used in this func graph
  170. const AnfNodeCounterMap &free_variables();
  171. void CopyFreeVariables(const AnfNodeCounterMap &others);
  172. void ClearFreeVariables();
  173. bool AddFreeVariable(AnfNodePtr node, int count = 1);
  174. bool DropFreeVariable(AnfNodePtr node);
  175. // get all vars required by this func graph
  176. const BaseRefCounterMap &free_variables_total();
  177. // Return the set of graphs free_variables_total belong to.
  178. std::vector<AnfNodePtr> free_variables_nodes();
  179. // get all vars that are func graphs
  180. std::vector<FuncGraphPtr> free_variables_func_graphs();
  181. // get all value nodes of func graph directly used by this func graph
  182. const AnfNodeCounterMap &func_graph_value_nodes();
  183. void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others);
  184. void ClearFuncGraphValueNodes();
  185. bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1);
  186. bool DropFuncGraphValueNode(AnfNodePtr node);
  187. // get all value nodes of J func graph directly used by this func graph
  188. const AnfNodeCounterMap &j_func_graph_value_nodes();
  189. void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others);
  190. void ClearJFuncGraphValueNodes();
  191. void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1);
  192. void DropJFuncGraphValueNode(AnfNodePtr node);
  193. // get all func graphs nested used by this func graph
  194. const FuncGraphSet &func_graphs_used_total();
  195. // get all user value nodes of this func graph, by CNode and its input's index
  196. const CNodeIndexCounterMap &func_graph_cnodes_index();
  197. void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes);
  198. void ClearFuncGraphCNodesIndex();
  199. void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
  200. void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
  201. // Return the parent of this graph.
  202. FuncGraphPtr parent();
  203. // Return the children of this graph.
  204. const FuncGraphSet &children();
  205. // Return the scope of this graph, scope have graph self but children not have.
  206. const FuncGraphSet &scope();
  207. // Return whether this graph is recursive
  208. bool recursive();
  209. // Return graphs which forms a recursive loop
  210. std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
  211. std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
  212. void DumpFuncGraph(const std::string &path = "./func_graph.dot");
  213. bool operator==(const Value &other) const override {
  214. if (other.isa<FuncGraph>()) {
  215. return &other == this;
  216. } else {
  217. return false;
  218. }
  219. }
  220. void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  221. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  222. int pos_args_input_count);
  223. void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  224. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  225. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  226. void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  227. const std::vector<AnfNodePtr> &specialized_parameter_list,
  228. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  229. const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
  230. void add_parameter_obj_node(const AnfNodePtr &p);
  231. std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
  232. std::unordered_map<std::string, bool> flags_;
  233. std::unordered_map<std::string, FuncGraphTransform> transforms_;
  234. // parameter default value
  235. std::map<std::string, AnfNodePtr> parameter_default_value_;
  236. std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
  237. std::list<CNodePtr> GetOrderedCnodes();
  238. void EraseUnusedNodeInOrder(const AnfNodePtr &n);
  239. void EraseUnusedNodeInOrder();
  240. void CheckOrder();
  241. void DumpCNodeList();
  242. void ReleaseFullOrderToEffectOrder();
  243. void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
  244. bool HasEffect(const CNodePtr &cnode);
  245. private:
  246. // graph is manipulated by manager and others
  247. friend FuncGraphManager;
  248. // all nodes of the function
  249. AnfNodeSet nodes_;
  250. // all value nodes of the function
  251. AnfNodeCounterMap value_nodes_;
  252. // all func graph value nodes of the function
  253. AnfNodeCounterMap func_graph_value_nodes_;
  254. // all free variables of the function
  255. AnfNodeCounterMap free_variables_;
  256. // all value nodes calling J in the function
  257. AnfNodeCounterMap j_func_graph_value_nodes_;
  258. // all user value nodes of this func graph, recording by CNode and its input's index
  259. CNodeIndexCounterMap func_graph_cnodes_index_;
  260. // parameters of this function
  261. std::vector<AnfNodePtr> parameters_;
  262. std::vector<AnfNodePtr> paramter_obj_nodes_;
  263. // whether there is a *args and **kwargs, and count kwonlyargs'number
  264. bool has_vararg_;
  265. bool has_kwarg_;
  266. int kwonlyargs_count_;
  267. // the hyper param is placed on the top graph,
  268. // and positioned in the end of the param list, so we record the number to trace the position
  269. size_t hyper_param_count_;
  270. // the argument input list for the graph used to generate this graph
  271. bool is_generated_;
  272. // the cnode that calls 'return' primitive
  273. // we use shared pointer to manage it.
  274. CNodePtr return_;
  275. // back-ref to its manager
  276. // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
  277. // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
  278. // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
  279. // In some ut test cases, they may use local FuncGraphManager in function which
  280. // generating the func graph, when go outside of that function, func graph will have no
  281. // FuncGraphManager. In that special case, Manage() should be called to make the func graph
  282. // managed.
  283. std::weak_ptr<FuncGraphManager> manager_;
  284. GraphDebugInfoPtr debug_info_;
  285. void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  286. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  287. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  288. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
  289. // CNode order which relates to origin code order
  290. std::list<CNodePtr> order_;
  291. };
  292. inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
  293. MS_EXCEPTION_IF_NULL(fg);
  294. return fg->NewCNode(inputs);
  295. }
  296. // Find the root cnodes of a segment of cnodes.
  297. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
  298. // Find the leaf cnodes of a segment of cnodes.
  299. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
  300. } // namespace mindspore
  301. #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_