You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.h 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  19. #define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include <list>
  24. #include <memory>
  25. #include <unordered_map>
  26. #include <unordered_set>
  27. #include <functional>
  28. #include "ir/anf.h"
  29. #include "ir/manager.h"
  30. #include "utils/ordered_set.h"
  31. #include "utils/ordered_map.h"
  32. #include "utils/base_ref.h"
  33. namespace mindspore {
  34. using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
  35. using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
  36. template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
  37. using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
  38. using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
  39. using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
  40. using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
  41. const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
  42. const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
  43. const char FUNC_GRAPH_FLAG_CORE[] = "core";
  44. const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
  45. namespace abstract {
  46. class AbstractKeywordArg;
  47. using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
  48. class AbstractFunction;
  49. using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
  50. } // namespace abstract
  51. // ANF transform class
  52. // either a primitive or a func_graph
  53. class FuncGraphTransform {
  54. public:
  55. enum Type { kGtPrimitive, kGtFuncGraph };
  56. explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
  57. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  58. explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
  59. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  60. FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
  61. ~FuncGraphTransform() = default;
  62. Type type() const {
  63. if (IsFuncGraph()) {
  64. return kGtFuncGraph;
  65. } else {
  66. return kGtPrimitive;
  67. }
  68. }
  69. bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
  70. bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
  71. FuncGraphPtr func_graph() const { return func_graph_.lock(); }
  72. PrimitivePtr primitive() const { return prim_; }
  73. FuncGraphTransform &operator=(const FuncGraphTransform &t) {
  74. if (this != &t) {
  75. prim_ = t.prim_;
  76. func_graph_ = t.func_graph_;
  77. }
  78. return *this;
  79. }
  80. private:
  81. PrimitivePtr prim_;
  82. // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
  83. // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
  84. // FPropRemapper::FinalizeGraph().
  85. FuncGraphWeakPtr func_graph_;
  86. static const PrimitivePtr func_graph_prim_;
  87. };
  // Value subclass that FuncGraph derives from. It declares only defaulted construction and
  // destruction plus whatever MS_DECLARE_PARENT (defined elsewhere) expands to; presumably that
  // macro registers FuncGraphBase as a child of Value for isa<>/cast<> checks — confirm.
  88. class FuncGraphBase : public Value {
  89. public:
  90. FuncGraphBase() = default;
  91. ~FuncGraphBase() override = default;
  92. MS_DECLARE_PARENT(FuncGraphBase, Value);
  93. };
  94. extern const char kFuncGraphFlagUndetermined[];
  95. class FuncGraph : public FuncGraphBase {
  96. public:
  97. FuncGraph();
  98. ~FuncGraph() override = default;
  99. MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
  100. // get the graph's abstract
  101. abstract::AbstractFunctionPtr abstract();
  102. abstract::AbstractBasePtr MakeAbstractClosure(const abstract::AnalysisContextPtr &context);
  103. // return the graph's output, or nullptr if not yet deduced
  104. AnfNodePtr output() const;
  105. void set_output(const AnfNodePtr &value, bool force_new_ret = false);
  106. const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
  107. virtual ParameterPtr add_parameter();
  108. void add_parameter(const ParameterPtr &p);
  109. void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
  110. // add a weight parameter with specific name
  111. ParameterPtr AddWeightParameter(const std::string &name);
  112. // create a cnode with given inputs, bound to this graph
  113. virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
  114. // create a cnode with given inputs, bound to this graph, and set to specific scope
  115. CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
  116. // Functions for handling variable argument, keyword-only arguments and variable keyword argument
  117. AnfNodePtr GetDefaultValueByName(const std::string &name);
  118. void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
  119. parameter_default_value_[name] = node;
  120. }
  121. void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
  122. void ClearDefaultValues();
  123. size_t GetDefaultValueCount();
  124. std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
  125. void set_has_vararg(bool has_) { has_vararg_ = has_; }
  126. bool has_vararg() const { return has_vararg_; }
  127. AnfNodePtr GetVariableArgParameter();
  128. std::string GetVariableArgName();
  129. void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
  130. bool has_kwarg() const { return has_kwarg_; }
  131. void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
  132. int kwonlyargs_count() const { return kwonlyargs_count_; }
  133. AnfNodePtr GetVariableKwargParameter();
  134. std::string GetVariableKwargName();
  135. void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
  136. size_t hyper_param_count() const { return hyper_param_count_; }
  137. int GetPositionalArgsCount() const;
  138. AnfNodePtr GetParameterByName(const std::string &name);
  139. bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
  140. FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
  141. void set_is_generate(bool generated) { is_generated_ = generated; }
  142. bool is_generated() const { return is_generated_; }
  143. bool has_flag(const std::string &flag);
  144. std::unordered_map<std::string, bool> &flags() { return flags_; }
  145. void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; }
  146. void set_flags(const std::string &key, const bool value) { flags_[key] = value; }
  147. std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
  148. void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
  149. transforms_ = transforms;
  150. }
  151. CNodePtr get_return() const { return return_; }
  152. void set_return(const CNodePtr &cnode) { return_ = cnode; }
  153. FuncGraphManagerPtr manager() const { return manager_.lock(); }
  154. void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
  155. std::string ToString() const override;
  156. GraphDebugInfoPtr debug_info();
  157. void set_debug_info(const GraphDebugInfoPtr &info) {
  158. if (info == nullptr) {
  159. MS_LOG(EXCEPTION) << "Graph set null debug info";
  160. }
  161. this->debug_info_ = info;
  162. }
  163. // get all nodes belonging to this func graph
  164. const AnfNodeSet &nodes();
  165. void CopyNodes(const FuncGraphPtr &source);
  166. void ClearNodes();
  167. void AddNode(AnfNodePtr node);
  168. void DropNode(AnfNodePtr node);
  169. // get all value_nodes belonging to this func graph
  170. const AnfNodeCounterMap &value_nodes();
  171. void CopyValueNodes(const FuncGraphPtr &source);
  172. void ClearValueNodes();
  173. void AddValueNode(AnfNodePtr node, int count = 1);
  174. void DropValueNode(AnfNodePtr node);
  175. // get all free vars directly used in this func graph
  176. const AnfNodeCounterMap &free_variables();
  177. void CopyFreeVariables(const FuncGraphPtr &source);
  178. void ClearFreeVariables();
  179. bool AddFreeVariable(AnfNodePtr node, int count = 1);
  180. bool DropFreeVariable(AnfNodePtr node);
  181. // get all vars required by this func graph
  182. const BaseRefCounterMap &free_variables_total();
  183. // Return the set of graphs free_variables_total belong to.
  184. std::vector<AnfNodePtr> free_variables_nodes();
  185. // get all vars that are func graphs
  186. std::vector<FuncGraphPtr> free_variables_func_graphs();
  187. // get all value nodes of func graph directly used by this func graph
  188. const FuncGraphCounterMap &func_graphs_used();
  189. void CopyFuncGraphsUsed(const FuncGraphPtr &source);
  190. void ClearFuncGraphsUsed();
  191. bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
  192. bool DropFuncGraphUsed(FuncGraphPtr fg);
  193. // get all value nodes of J func graph directly used by this func graph
  194. const FuncGraphCounterMap &j_func_graphs();
  195. void CopyJFuncGraphs(const FuncGraphPtr &source);
  196. void ClearJFuncGraphs();
  197. void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
  198. void DropJFuncGraph(FuncGraphPtr fg);
  199. // get all func graphs nested used by this func graph
  200. const FuncGraphSet &func_graphs_used_total();
  201. // get all user value nodes of this func graph, by CNode and its input's index
  202. const CNodeIndexCounterMap &func_graph_cnodes_index();
  203. void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
  204. void ClearFuncGraphCNodesIndex();
  205. void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
  206. void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
  207. // Return the parent of this graph.
  208. FuncGraphPtr parent();
  209. // Return the children of this graph.
  210. const FuncGraphSet &children();
  211. // Return the scope of this graph, scope have graph self but children not have.
  212. const FuncGraphSet &scope();
  213. // Return whether this graph is recursive
  214. bool recursive();
  215. // Return graphs which forms a recursive loop
  216. std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
  217. std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
  218. void DumpFuncGraph(const std::string &path = "./func_graph.dot");
  219. bool operator==(const Value &other) const override {
  220. if (other.isa<FuncGraph>()) {
  221. return &other == this;
  222. } else {
  223. return false;
  224. }
  225. }
  226. void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  227. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  228. int pos_args_input_count);
  229. void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  230. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  231. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  232. void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  233. const std::vector<AnfNodePtr> &specialized_parameter_list,
  234. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  235. const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
  236. void add_parameter_obj_node(const AnfNodePtr &p);
  237. std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
  238. std::unordered_map<std::string, bool> flags_;
  239. std::unordered_map<std::string, FuncGraphTransform> transforms_;
  240. // parameter default value
  241. std::map<std::string, AnfNodePtr> parameter_default_value_;
  242. std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
  243. size_t seen_;
  244. std::list<CNodePtr> GetOrderedCnodes();
  245. void EraseUnusedNodeInOrder(const AnfNodePtr &n);
  246. void EraseUnusedNodeInOrder();
  247. void CheckOrder();
  248. void DumpCNodeList();
  249. void ReleaseFullOrderToEffectOrder();
  250. void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
  251. bool HasEffect(const CNodePtr &cnode);
  252. private:
  253. // graph is manipulated by manager and others
  254. friend FuncGraphManager;
  255. // all nodes of the function
  256. AnfNodeSet nodes_;
  257. // all value nodes of the function
  258. AnfNodeCounterMap value_nodes_;
  259. // all func graph value nodes of the function
  260. FuncGraphCounterMap func_graphs_used_;
  261. // all free variables of the function
  262. AnfNodeCounterMap free_variables_;
  263. // all value nodes calling J in the function
  264. FuncGraphCounterMap j_func_graphs_;
  265. // all user value nodes of this func graph, recording by CNode and its input's index
  266. CNodeIndexCounterMap func_graph_cnodes_index_;
  267. // parameters of this function
  268. std::vector<AnfNodePtr> parameters_;
  269. std::vector<AnfNodePtr> paramter_obj_nodes_;
  270. // whether there is a *args and **kwargs, and count kwonlyargs'number
  271. bool has_vararg_;
  272. bool has_kwarg_;
  273. int kwonlyargs_count_;
  274. // the hyper param is placed on the top graph,
  275. // and positioned in the end of the param list, so we record the number to trace the position
  276. size_t hyper_param_count_;
  277. // the argument input list for the graph used to generate this graph
  278. bool is_generated_;
  279. // the cnode that calls 'return' primitive
  280. // we use shared pointer to manage it.
  281. CNodePtr return_;
  282. // back-ref to its manager
  283. // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
  284. // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
  285. // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
  286. // In some ut test cases, they may use local FuncGraphManager in function which
  287. // generating the func graph, when go outside of that function, func graph will have no
  288. // FuncGraphManager. In that special case, Manage() should be called to make the func graph
  289. // managed.
  290. std::weak_ptr<FuncGraphManager> manager_;
  291. GraphDebugInfoPtr debug_info_;
  292. void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  293. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  294. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  295. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
  296. // CNode order which relates to origin code order
  297. std::list<CNodePtr> order_;
  298. };
  299. inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
  300. MS_EXCEPTION_IF_NULL(fg);
  301. return fg->NewCNode(inputs);
  302. }
  303. size_t NewFgSeenGeneration();
  304. // Find the root cnodes of a segment of cnodes.
  305. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
  306. // Find the leaf cnodes of a segment of cnodes.
  307. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
  308. } // namespace mindspore
  309. #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_