You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.h 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  19. #define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_
  20. #include <map>
  21. #include <string>
  22. #include <vector>
  23. #include <list>
  24. #include <memory>
  25. #include <unordered_map>
  26. #include <unordered_set>
  27. #include <functional>
  28. #include "ir/anf.h"
  29. #include "ir/manager.h"
  30. #include "utils/ordered_set.h"
  31. #include "utils/ordered_map.h"
  32. #include "utils/base_ref.h"
  33. namespace mindspore {
  34. using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
  35. using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
  36. struct CNodeIndexHasher {
  37. std::size_t operator()(const CNodeIndexPairPtr pair) const {
  38. MS_EXCEPTION_IF_NULL(pair);
  39. MS_EXCEPTION_IF_NULL(pair->first);
  40. return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
  41. }
  42. };
  43. struct CNodeIndexEqual {
  44. bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
  45. if (lhs == nullptr || rhs == nullptr) {
  46. return false;
  47. }
  48. if (lhs == rhs) {
  49. return true;
  50. }
  51. if (lhs->first != rhs->first) {
  52. return false;
  53. }
  54. if (lhs->second != rhs->second) {
  55. return false;
  56. }
  57. return true;
  58. }
  59. };
  60. template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
  61. using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
  62. using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
  63. using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
  64. using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
  65. const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
  66. const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
  67. const char FUNC_GRAPH_FLAG_CORE[] = "core";
  68. const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
  69. const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
  namespace abstract {
  // Forward declarations only: this header refers to these abstract-value types through
  // shared_ptr aliases and does not need their full definitions.
  class AbstractFunction;
  class AbstractKeywordArg;

  using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
  using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
  }  // namespace abstract
  76. // ANF transform class
  77. // either a primitive or a func_graph
  78. class FuncGraphTransform {
  79. public:
  80. enum Type { kGtPrimitive, kGtFuncGraph };
  81. explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
  82. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  83. explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
  84. : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
  85. FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
  86. ~FuncGraphTransform() = default;
  87. Type type() const {
  88. if (IsFuncGraph()) {
  89. return kGtFuncGraph;
  90. } else {
  91. return kGtPrimitive;
  92. }
  93. }
  94. bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
  95. bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
  96. FuncGraphPtr func_graph() const { return func_graph_.lock(); }
  97. PrimitivePtr primitive() const { return prim_; }
  98. FuncGraphTransform &operator=(const FuncGraphTransform &t) {
  99. if (this != &t) {
  100. prim_ = t.prim_;
  101. func_graph_ = t.func_graph_;
  102. }
  103. return *this;
  104. }
  105. private:
  106. PrimitivePtr prim_;
  107. // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
  108. // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
  109. // FPropRemapper::FinalizeGraph().
  110. FuncGraphWeakPtr func_graph_;
  111. static const PrimitivePtr func_graph_prim_;
  112. };
  113. class FuncGraphBase : public Value {
  114. public:
  115. FuncGraphBase() = default;
  116. ~FuncGraphBase() override = default;
  117. MS_DECLARE_PARENT(FuncGraphBase, Value);
  118. };
  119. extern const char kFuncGraphFlagUndetermined[];
  120. class FuncGraph : public FuncGraphBase {
  121. public:
  122. FuncGraph();
  123. ~FuncGraph() override = default;
  124. MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
  125. // get the graph's abstract
  126. abstract::AbstractFunctionPtr abstract();
  127. abstract::AbstractBasePtr MakeAbstractClosure(const abstract::AnalysisContextPtr &context);
  128. // return the graph's output, or nullptr if not yet deduced
  129. AnfNodePtr output() const;
  130. void set_output(const AnfNodePtr &value, bool force_new_ret = false);
  131. const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
  132. virtual ParameterPtr add_parameter();
  133. void add_parameter(const ParameterPtr &p);
  134. void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
  135. // add a weight parameter with specific name
  136. ParameterPtr AddWeightParameter(const std::string &name);
  137. // create a cnode with given inputs, bound to this graph
  138. virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
  139. // create a cnode with given inputs, bound to this graph, and set to specific scope
  140. CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
  141. // Functions for handling variable argument, keyword-only arguments and variable keyword argument
  142. AnfNodePtr GetDefaultValueByName(const std::string &name);
  143. void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
  144. parameter_default_value_[name] = node;
  145. }
  146. void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
  147. void ClearDefaultValues();
  148. size_t GetDefaultValueCount();
  149. std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
  150. void set_has_vararg(bool has_) { has_vararg_ = has_; }
  151. bool has_vararg() const { return has_vararg_; }
  152. AnfNodePtr GetVariableArgParameter();
  153. std::string GetVariableArgName();
  154. void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
  155. bool has_kwarg() const { return has_kwarg_; }
  156. void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
  157. int kwonlyargs_count() const { return kwonlyargs_count_; }
  158. AnfNodePtr GetVariableKwargParameter();
  159. std::string GetVariableKwargName();
  160. void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
  161. size_t hyper_param_count() const { return hyper_param_count_; }
  162. int GetPositionalArgsCount() const;
  163. AnfNodePtr GetParameterByName(const std::string &name);
  164. bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
  165. FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
  166. void set_is_generate(bool generated) { is_generated_ = generated; }
  167. bool is_generated() const { return is_generated_; }
  168. std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
  169. void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
  170. for (auto &attr : attrs) {
  171. attrs_[attr.first] = attr.second;
  172. }
  173. }
  174. bool has_flag(const std::string &key);
  175. void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
  176. void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
  177. bool has_attr(const std::string &key);
  178. ValuePtr get_attr(const std::string &key);
  179. void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
  180. std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
  181. void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
  182. transforms_ = transforms;
  183. }
  184. CNodePtr get_return() const { return return_; }
  185. void set_return(const CNodePtr &cnode) { return_ = cnode; }
  186. FuncGraphManagerPtr manager() const { return manager_.lock(); }
  187. void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
  188. std::string ToString() const override;
  189. GraphDebugInfoPtr debug_info();
  190. void set_debug_info(const GraphDebugInfoPtr &info) {
  191. if (info == nullptr) {
  192. MS_LOG(EXCEPTION) << "Graph set null debug info";
  193. }
  194. this->debug_info_ = info;
  195. }
  196. // get all nodes belonging to this func graph
  197. const AnfNodeSet &nodes();
  198. void CopyNodes(const FuncGraphPtr &source);
  199. void ClearNodes();
  200. void AddNode(AnfNodePtr node);
  201. void DropNode(AnfNodePtr node);
  202. // get all value_nodes belonging to this func graph
  203. const AnfNodeCounterMap &value_nodes();
  204. void CopyValueNodes(const FuncGraphPtr &source);
  205. void ClearValueNodes();
  206. void AddValueNode(AnfNodePtr node, int count = 1);
  207. void DropValueNode(AnfNodePtr node);
  208. // get all free vars directly used in this func graph
  209. const AnfNodeCounterMap &free_variables();
  210. void CopyFreeVariables(const FuncGraphPtr &source);
  211. void ClearFreeVariables();
  212. bool AddFreeVariable(AnfNodePtr node, int count = 1);
  213. bool DropFreeVariable(AnfNodePtr node);
  214. // get all vars required by this func graph
  215. const BaseRefCounterMap &free_variables_total();
  216. // Return the set of graphs free_variables_total belong to.
  217. std::vector<AnfNodePtr> free_variables_nodes();
  218. // get all vars that are func graphs
  219. std::vector<FuncGraphPtr> free_variables_func_graphs();
  220. // get all value nodes of func graph directly used by this func graph
  221. const FuncGraphCounterMap &func_graphs_used();
  222. void CopyFuncGraphsUsed(const FuncGraphPtr &source);
  223. void ClearFuncGraphsUsed();
  224. bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
  225. bool DropFuncGraphUsed(FuncGraphPtr fg);
  226. // get all value nodes of J func graph directly used by this func graph
  227. const FuncGraphCounterMap &j_func_graphs();
  228. void CopyJFuncGraphs(const FuncGraphPtr &source);
  229. void ClearJFuncGraphs();
  230. void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
  231. void DropJFuncGraph(FuncGraphPtr fg);
  232. // get all func graphs nested used by this func graph
  233. const FuncGraphSet &func_graphs_used_total();
  234. // get all user value nodes of this func graph, by CNode and its input's index
  235. const CNodeIndexCounterMap &func_graph_cnodes_index();
  236. void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
  237. void ClearFuncGraphCNodesIndex();
  238. void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
  239. void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
  240. // Return the parent of this graph.
  241. FuncGraphPtr parent();
  242. // Return the children of this graph.
  243. const FuncGraphSet &children();
  244. // Return the scope of this graph, scope have graph self but children not have.
  245. const FuncGraphSet &scope();
  246. // Return whether this graph is recursive
  247. bool recursive();
  248. // Return graphs which forms a recursive loop
  249. std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
  250. std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
  251. void DumpFuncGraph(const std::string &path = "./func_graph.dot");
  252. bool operator==(const Value &other) const override {
  253. if (other.isa<FuncGraph>()) {
  254. return &other == this;
  255. } else {
  256. return false;
  257. }
  258. }
  259. void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  260. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  261. int pos_args_input_count);
  262. void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector<AnfNodePtr> *specialized_parameter_list,
  263. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  264. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  265. void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  266. const std::vector<AnfNodePtr> &specialized_parameter_list,
  267. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes);
  268. const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
  269. void add_parameter_obj_node(const AnfNodePtr &p);
  270. std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
  271. std::unordered_map<std::string, ValuePtr> attrs_;
  272. std::unordered_map<std::string, FuncGraphTransform> transforms_;
  273. // parameter default value
  274. std::map<std::string, AnfNodePtr> parameter_default_value_;
  275. std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
  276. size_t seen_;
  277. std::list<CNodePtr> GetOrderedCnodes();
  278. void EraseUnusedNodeInOrder(const AnfNodePtr &n);
  279. void EraseUnusedNodeInOrder();
  280. void CheckOrder();
  281. void DumpCNodeList();
  282. void ReleaseFullOrderToEffectOrder();
  283. void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
  284. bool HasEffect(const CNodePtr &cnode);
  285. private:
  286. // graph is manipulated by manager and others
  287. friend FuncGraphManager;
  288. // all nodes of the function
  289. AnfNodeSet nodes_;
  290. // all value nodes of the function
  291. AnfNodeCounterMap value_nodes_;
  292. // all func graph value nodes of the function
  293. FuncGraphCounterMap func_graphs_used_;
  294. // all free variables of the function
  295. AnfNodeCounterMap free_variables_;
  296. // all value nodes calling J in the function
  297. FuncGraphCounterMap j_func_graphs_;
  298. // all user value nodes of this func graph, recording by CNode and its input's index
  299. CNodeIndexCounterMap func_graph_cnodes_index_;
  300. // parameters of this function
  301. std::vector<AnfNodePtr> parameters_;
  302. std::vector<AnfNodePtr> paramter_obj_nodes_;
  303. // whether there is a *args and **kwargs, and count kwonlyargs'number
  304. bool has_vararg_;
  305. bool has_kwarg_;
  306. int kwonlyargs_count_;
  307. // the hyper param is placed on the top graph,
  308. // and positioned in the end of the param list, so we record the number to trace the position
  309. size_t hyper_param_count_;
  310. // the argument input list for the graph used to generate this graph
  311. bool is_generated_;
  312. // the cnode that calls 'return' primitive
  313. // we use shared pointer to manage it.
  314. CNodePtr return_;
  315. // back-ref to its manager
  316. // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
  317. // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
  318. // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
  319. // In some ut test cases, they may use local FuncGraphManager in function which
  320. // generating the func graph, when go outside of that function, func graph will have no
  321. // FuncGraphManager. In that special case, Manage() should be called to make the func graph
  322. // managed.
  323. std::weak_ptr<FuncGraphManager> manager_;
  324. GraphDebugInfoPtr debug_info_;
  325. void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  326. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  327. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  328. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
  329. // CNode order which relates to origin code order
  330. std::list<CNodePtr> order_;
  331. };
  332. inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
  333. MS_EXCEPTION_IF_NULL(fg);
  334. return fg->NewCNode(inputs);
  335. }
  336. size_t NewFgSeenGeneration();
  337. // Find the root cnodes of a segment of cnodes.
  338. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
  339. // Find the leaf cnodes of a segment of cnodes.
  340. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
  341. } // namespace mindspore
  342. #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_