You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.cc 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/func_graph.h"
  19. #include <algorithm>
  20. #include <sstream>
  21. #include <utility>
  22. #include "ir/manager.h"
  23. #include "ir/func_graph_cloner.h"
  24. #include "operator/ops.h"
  25. #include "utils/ordered_set.h"
  26. #include "pipeline/static_analysis/static_analysis.h"
  27. #include "pipeline/static_analysis/abstract_function.h"
  28. #include "debug/anf_ir_dump.h"
  29. #include "debug/trace.h"
  30. #include "debug/draw.h"
  31. #include "debug/label.h"
  32. namespace mindspore {
  33. using mindspore::abstract::AbstractFunction;
  34. using mindspore::abstract::AbstractFunctionPtr;
  35. using mindspore::abstract::AnalysisContextPtr;
  36. using mindspore::abstract::PrimitiveAbstractClosure;
  37. using mindspore::abstract::VirtualAbstractClosure;
  38. /*
  39. * Methods of Graph
  40. */
  41. FuncGraph::FuncGraph()
  42. : flags_(),
  43. transforms_(),
  44. parameter_default_value_(),
  45. parameters_(),
  46. has_vararg_(false),
  47. has_kwarg_(false),
  48. kwonlyargs_count_(0),
  49. hyper_param_count_(0),
  50. is_generated_(false),
  51. return_(nullptr),
  52. manager_(std::weak_ptr<FuncGraphManager>()) {
  53. debug_info_ = std::make_shared<GraphDebugInfo>();
  54. }
  55. AbstractFunctionPtr FuncGraph::abstract() {
  56. AbstractBasePtrList args_spec_list;
  57. for (auto &p : parameters_) {
  58. MS_EXCEPTION_IF_NULL(p);
  59. if (p->abstract() == nullptr) {
  60. MS_LOG(ERROR) << "Error!!";
  61. return nullptr;
  62. }
  63. args_spec_list.push_back(p->abstract());
  64. }
  65. if (nullptr == output()) {
  66. MS_LOG(ERROR) << "Error func graph no output";
  67. return nullptr;
  68. }
  69. return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
  70. }
  71. abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
  72. AnalysisContextPtr temp_context = context;
  73. if (temp_context == nullptr) {
  74. temp_context = abstract::AnalysisContext::DummyContext();
  75. }
  76. return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
  77. }
  78. AnfNodePtr FuncGraph::output() const {
  79. // If return value is set, return should have two inputs.
  80. if (return_ != nullptr && return_->inputs().size() == 2) {
  81. return return_->input(1);
  82. } else {
  83. // If not set yet, return nullptr.
  84. return nullptr;
  85. }
  86. }
  87. void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
  88. if (force_new_ret || return_ == nullptr) {
  89. std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
  90. FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  91. return_ = this_graph->NewCNode(params);
  92. } else {
  93. if (manager_.lock()) {
  94. manager_.lock()->SetEdge(return_, 1, value);
  95. } else {
  96. return_->set_input(1, value);
  97. }
  98. }
  99. return_->set_abstract(value->abstract());
  100. AnfNodePtr input0 = return_->input(0);
  101. PrimitivePtr return_prim = prim::kPrimReturn;
  102. auto f = std::make_shared<PrimitiveAbstractClosure>(return_prim, input0);
  103. input0->set_abstract(f);
  104. }
  105. ParameterPtr FuncGraph::add_parameter() {
  106. FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
  107. ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
  108. add_parameter(p);
  109. return p;
  110. }
  111. void FuncGraph::add_parameter(const ParameterPtr &p) {
  112. if (manager_.lock()) {
  113. std::vector<AnfNodePtr> new_params = parameters_;
  114. new_params.push_back(p);
  115. manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
  116. } else {
  117. parameters_.push_back(p);
  118. }
  119. }
  120. ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
  121. FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  122. ParameterPtr p = std::make_shared<Parameter>(this_graph);
  123. p->set_name(name);
  124. p->debug_info()->set_name(name);
  125. std::vector<AnfNodePtr> new_params = parameters_;
  126. // append parameter
  127. new_params.push_back(p);
  128. if (manager_.lock()) {
  129. manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
  130. } else {
  131. parameters_.push_back(p);
  132. }
  133. hyper_param_count_++;
  134. return p;
  135. }
  136. bool FuncGraph::has_flag(const std::string &flag) {
  137. if (flags_.count(flag)) {
  138. return flags_[flag];
  139. }
  140. return false;
  141. }
  142. CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  143. CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
  144. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  145. order_.push_back(cnode);
  146. MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order.";
  147. }
  148. return cnode;
  149. }
  150. CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
  151. CNodePtr app = NewCNode(inputs);
  152. app->set_scope(scope);
  153. return app;
  154. }
  155. void FuncGraph::DumpCNodeList() {
  156. MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
  157. for (const auto &cnode : order_) {
  158. MS_LOG(INFO) << cnode->DebugString();
  159. }
  160. }
  161. std::string FuncGraph::ToString() const {
  162. return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
  163. }
  164. GraphDebugInfoPtr FuncGraph::debug_info() {
  165. MS_EXCEPTION_IF_NULL(this->debug_info_);
  166. if (this->debug_info_->get_graph() == nullptr) {
  167. this->debug_info_->set_graph(shared_from_base<FuncGraph>());
  168. }
  169. return this->debug_info_;
  170. }
  171. const AnfNodeSet &FuncGraph::nodes() {
  172. auto mng = manager_.lock();
  173. MS_EXCEPTION_IF_NULL(mng);
  174. auto &nodes = mng->nodes();
  175. return nodes[shared_from_base<FuncGraph>()];
  176. }
  177. const AnfNodeCounterMap &FuncGraph::value_nodes() {
  178. auto mng = manager_.lock();
  179. MS_EXCEPTION_IF_NULL(mng);
  180. auto &cts = mng->valuenodes();
  181. return cts[shared_from_base<FuncGraph>()];
  182. }
  183. const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
  184. auto mng = manager_.lock();
  185. MS_EXCEPTION_IF_NULL(mng);
  186. auto &fv_direct = mng->free_variables_direct();
  187. return fv_direct[shared_from_base<FuncGraph>()];
  188. }
  189. const BaseRefCounterMap &FuncGraph::free_variables_total() {
  190. auto mng = manager_.lock();
  191. MS_EXCEPTION_IF_NULL(mng);
  192. auto &fv_total = mng->free_variables_total();
  193. return fv_total[shared_from_base<FuncGraph>()];
  194. }
  195. std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
  196. std::vector<AnfNodePtr> nodes;
  197. const auto &fv_total = this->free_variables_total();
  198. for (auto &p : fv_total) {
  199. auto key = p.first;
  200. if (utils::isa<AnfNodePtr>(key)) {
  201. nodes.push_back(utils::cast<AnfNodePtr>(key));
  202. }
  203. }
  204. return nodes;
  205. }
  206. std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
  207. std::vector<FuncGraphPtr> func_graphs;
  208. const auto &fv_total = this->free_variables_total();
  209. for (auto &p : fv_total) {
  210. auto key = p.first;
  211. if (utils::isa<FuncGraphPtr>(key)) {
  212. func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
  213. }
  214. }
  215. return func_graphs;
  216. }
  217. const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
  218. auto mng = manager_.lock();
  219. MS_EXCEPTION_IF_NULL(mng);
  220. auto &used = mng->func_graphs_used();
  221. return used[shared_from_base<FuncGraph>()];
  222. }
  223. const FuncGraphSet &FuncGraph::func_graphs_used_total() {
  224. auto mng = manager_.lock();
  225. MS_EXCEPTION_IF_NULL(mng);
  226. auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
  227. return used;
  228. }
  229. const FuncGraphCounterMap &FuncGraph::func_graph_users() {
  230. auto mng = manager_.lock();
  231. MS_EXCEPTION_IF_NULL(mng);
  232. auto &users = mng->func_graph_users();
  233. return users[shared_from_base<FuncGraph>()];
  234. }
  235. const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
  236. auto mng = manager_.lock();
  237. MS_EXCEPTION_IF_NULL(mng);
  238. auto &users = mng->func_graph_user_cnodes();
  239. return users[shared_from_base<FuncGraph>()];
  240. }
  241. FuncGraphPtr FuncGraph::parent() {
  242. // report the bug early.
  243. if (manager_.lock() == nullptr) {
  244. MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
  245. << " NodeInfo: " << trace::GetDebugInfo(debug_info());
  246. }
  247. auto mng = manager_.lock();
  248. MS_EXCEPTION_IF_NULL(mng);
  249. return mng->parent(shared_from_base<FuncGraph>());
  250. }
  251. const FuncGraphSet &FuncGraph::children() {
  252. auto mng = manager_.lock();
  253. MS_EXCEPTION_IF_NULL(mng);
  254. return mng->children(shared_from_base<FuncGraph>());
  255. }
  256. const FuncGraphSet &FuncGraph::scope() {
  257. auto mng = manager_.lock();
  258. MS_EXCEPTION_IF_NULL(mng);
  259. return mng->scopes(shared_from_base<FuncGraph>());
  260. }
  261. bool FuncGraph::recursive() {
  262. auto mng = manager_.lock();
  263. MS_EXCEPTION_IF_NULL(mng);
  264. return mng->recursive(shared_from_base<FuncGraph>());
  265. }
  266. std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
  267. auto mng = manager_.lock();
  268. MS_EXCEPTION_IF_NULL(mng);
  269. return mng->recursive_graphs(shared_from_base<FuncGraph>());
  270. }
  271. void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
  272. AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
  273. auto itr = this->parameter_default_value_.find(name);
  274. if (itr == parameter_default_value_.end()) {
  275. return nullptr;
  276. }
  277. auto default_value = itr->second;
  278. if (default_value == nullptr) {
  279. MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
  280. }
  281. if (IsValueNode<NullObj>(default_value)) {
  282. return nullptr;
  283. }
  284. return default_value;
  285. }
  286. // set the default values
  287. void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
  288. auto all_is_null = std::all_of(value_list.begin(), value_list.end(),
  289. [](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); });
  290. if (value_list.empty()) {
  291. all_is_null = true;
  292. }
  293. for (size_t i = 0; i < name_list.size(); ++i) {
  294. if (!all_is_null) {
  295. this->parameter_default_value_[name_list[i]] = value_list[i];
  296. }
  297. }
  298. }
  299. void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
  300. size_t FuncGraph::GetDefaultValueCount() {
  301. int null_count =
  302. std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
  303. [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); });
  304. return parameter_default_value_.size() - IntToSize(null_count);
  305. }
  306. AnfNodePtr FuncGraph::GetVariableArgParameter() {
  307. if (!has_vararg_) {
  308. return nullptr;
  309. }
  310. if (has_kwarg_) {
  311. if (parameters_.size() < hyper_param_count_ + 2) {
  312. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  313. << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
  314. }
  315. return parameters_[parameters_.size() - hyper_param_count_ - 2];
  316. }
  317. if (parameters_.size() < hyper_param_count_ + 1) {
  318. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  319. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  320. }
  321. return parameters_[parameters_.size() - hyper_param_count_ - 1];
  322. }
  323. std::string FuncGraph::GetVariableArgName() {
  324. if (!has_vararg_) {
  325. return "";
  326. }
  327. if (has_kwarg_) {
  328. if (parameters_.size() < hyper_param_count_ + 2) {
  329. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  330. << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
  331. }
  332. return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast<ParameterPtr>()->name();
  333. }
  334. if (parameters_.size() < hyper_param_count_ + 1) {
  335. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  336. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  337. }
  338. return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>()->name();
  339. }
  340. AnfNodePtr FuncGraph::GetVariableKwargParameter() {
  341. if (has_kwarg_) {
  342. if (parameters_.size() < hyper_param_count_ + 1) {
  343. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  344. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  345. }
  346. return parameters_[parameters_.size() - hyper_param_count_ - 1];
  347. }
  348. return nullptr;
  349. }
  350. std::string FuncGraph::GetVariableKwargName() {
  351. if (has_kwarg_) {
  352. if (parameters_.size() < hyper_param_count_ + 1) {
  353. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  354. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  355. }
  356. return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>()->name();
  357. }
  358. return "";
  359. }
  360. int FuncGraph::GetPositionalArgsCount() const {
  361. int count = SizeToInt(parameters_.size());
  362. if (has_kwarg_) {
  363. count--;
  364. }
  365. if (has_vararg_) {
  366. count--;
  367. }
  368. return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
  369. }
  370. AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
  371. for (size_t i = 0; i < parameters_.size(); ++i) {
  372. MS_EXCEPTION_IF_NULL(parameters_[i]);
  373. auto param_cast = parameters_[i]->cast<ParameterPtr>();
  374. MS_EXCEPTION_IF_NULL(param_cast);
  375. if (param_cast->name() == name) {
  376. return parameters_[i];
  377. }
  378. }
  379. return nullptr;
  380. }
  381. void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
  382. std::vector<AnfNodePtr> *specialized_parameter_list,
  383. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
  384. int pos_args_input_count) {
  385. // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
  386. if (specialized_graph->has_vararg()) {
  387. TraceManager::DebugTrace(
  388. std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
  389. std::vector<AnfNodePtr> var_param_tuple_nodes;
  390. var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  391. if (variable_args_count < 0) {
  392. MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
  393. << " were given.";
  394. }
  395. // for python variable argument input , there is no upper limit
  396. for (int i = 0; i < variable_args_count; ++i) {
  397. ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
  398. std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i);
  399. p->set_name(param_name);
  400. MS_EXCEPTION_IF_NULL(p->debug_info());
  401. p->debug_info()->set_name(param_name);
  402. var_param_tuple_nodes.push_back(p);
  403. MS_EXCEPTION_IF_NULL(specialized_parameter_list);
  404. specialized_parameter_list->push_back(p);
  405. }
  406. auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes);
  407. (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
  408. TraceManager::EndTrace();
  409. } else if (variable_args_count > 0) {
  410. MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount()
  411. << " positional arguments, but " << pos_args_input_count << " were given.";
  412. }
  413. }
  414. void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
  415. std::vector<AnfNodePtr> *specialized_parameter_list,
  416. const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
  417. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
  418. std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
  419. std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
  420. for (const auto &kwarg : kwarg_list) {
  421. MS_EXCEPTION_IF_NULL(kwarg);
  422. std::string kw_param_name = kwarg->get_key();
  423. MS_EXCEPTION_IF_NULL(specialized_graph);
  424. AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
  425. // if not find correspoding parameter node
  426. if (param_node == nullptr) {
  427. if (!has_kwarg()) {
  428. MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
  429. } else {
  430. ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
  431. std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
  432. MS_EXCEPTION_IF_NULL(specialized_parameter_list);
  433. auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
  434. [param_name](const AnfNodePtr &node) {
  435. MS_EXCEPTION_IF_NULL(node);
  436. auto param = node->cast<ParameterPtr>();
  437. return param != nullptr && param->name() == param_name;
  438. });
  439. if (find_kw_arg_in_list) {
  440. MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name;
  441. }
  442. p->set_name(param_name);
  443. p->debug_info()->set_name(param_name);
  444. kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
  445. auto extract_node =
  446. specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p});
  447. kwarg_values_tuple_nodes.push_back(extract_node);
  448. specialized_parameter_list->push_back(p);
  449. }
  450. } else {
  451. auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
  452. // multiply values found given for parameter
  453. if (node_itr != specialized_parameter_list->end()) {
  454. MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name;
  455. } else {
  456. specialized_parameter_list->push_back(param_node);
  457. auto extract_node = specialized_graph->NewCNode(
  458. {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
  459. (void)repl_nodes->emplace(param_node, extract_node);
  460. }
  461. }
  462. }
  463. GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
  464. }
  465. void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
  466. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
  467. const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
  468. const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
  469. if (has_kwarg()) {
  470. MS_EXCEPTION_IF_NULL(specialized_graph);
  471. TraceManager::DebugTrace(
  472. std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
  473. auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
  474. auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
  475. auto make_dict_node =
  476. specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
  477. MS_EXCEPTION_IF_NULL(repl_nodes);
  478. (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
  479. TraceManager::EndTrace();
  480. }
  481. }
  482. bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
  483. // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
  484. // return the original graph
  485. if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
  486. return false;
  487. }
  488. // if the graph is generated for specific input, do not need to generate again
  489. if (is_generated()) {
  490. return false;
  491. }
  492. return true;
  493. }
  494. void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
  495. const std::vector<AnfNodePtr> &specialized_parameter_list,
  496. std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
  497. MS_EXCEPTION_IF_NULL(specialized_graph);
  498. for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
  499. auto param_node = specialized_graph->parameters()[i];
  500. MS_EXCEPTION_IF_NULL(param_node);
  501. auto param_name = param_node->cast<ParameterPtr>()->name();
  502. auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
  503. if (node_itr != specialized_parameter_list.end()) {
  504. continue;
  505. }
  506. if (param_name == specialized_graph->GetVariableArgName() ||
  507. param_name == specialized_graph->GetVariableKwargName()) {
  508. continue;
  509. }
  510. auto default_value = specialized_graph->GetDefaultValueByName(param_name);
  511. if (default_value == nullptr) {
  512. MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name;
  513. }
  514. MS_EXCEPTION_IF_NULL(repl_nodes);
  515. (void)repl_nodes->emplace(param_node, default_value);
  516. }
  517. }
  518. FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
  519. std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
  520. size_t arguments_count = args_spec_list.size();
  521. for (const auto &arg : args_spec_list) {
  522. // if it is a keyword argument
  523. MS_EXCEPTION_IF_NULL(arg);
  524. if (arg->isa<abstract::AbstractKeywordArg>()) {
  525. kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg));
  526. }
  527. }
  528. if (!NeedGenerate(kwarg_list)) {
  529. return shared_from_base<FuncGraph>();
  530. }
  531. FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
  532. size_t kwarg_count = kwarg_list.size();
  533. int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count());
  534. int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
  535. int variable_args_count = pos_args_input_count - pos_args_count;
  536. std::vector<AnfNodePtr> specialized_parameter_list;
  537. std::unordered_map<AnfNodePtr, AnfNodePtr> repl_nodes;
  538. // the parameters that has arg input, copy from original parameters
  539. for (size_t i = 0; i < IntToSize(pos_args_count); ++i) {
  540. specialized_parameter_list.push_back(specialized_graph->parameters()[i]);
  541. }
  542. GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count,
  543. pos_args_input_count);
  544. GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, &repl_nodes);
  545. GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
  546. // append hyper parameter to specialized_parameter_list
  547. MS_EXCEPTION_IF_NULL(specialized_graph);
  548. auto params = specialized_graph->parameters();
  549. (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
  550. std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
  551. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
  552. auto tr = manager->Transact();
  553. for (auto &node_pair : repl_nodes) {
  554. MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
  555. << node_pair.second->DebugString();
  556. (void)tr.Replace(node_pair.first, node_pair.second);
  557. }
  558. tr.SetParameters(specialized_graph, specialized_parameter_list);
  559. tr.Commit();
  560. specialized_graph->set_has_kwarg(false);
  561. specialized_graph->set_has_vararg(false);
  562. specialized_graph->set_kwonlyargs_count(0);
  563. specialized_graph->ClearDefaultValues();
  564. specialized_graph->set_is_generate(true);
  565. return specialized_graph;
  566. }
  567. void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
  568. std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
  569. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  570. MS_LOG(DEBUG) << "Return ordered cnodes.";
  571. return order_;
  572. } else {
  573. auto this_ptr = shared_from_base<FuncGraph>();
  574. auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
  575. auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
  576. std::list<CNodePtr> cnodes;
  577. auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
  578. for (const auto &node : nodes) {
  579. auto cnode = dyn_cast<CNode>(node);
  580. if (cnode) {
  581. cnodes.push_back(cnode);
  582. }
  583. }
  584. return cnodes;
  585. }
  586. }
  587. void FuncGraph::EraseUnusedNodeInOrder() {
  588. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  589. auto mng = manager_.lock();
  590. if (mng) {
  591. auto nodes = mng->nodes()[shared_from_base<FuncGraph>()];
  592. // Erase unused cnode.
  593. for (auto it = order_.begin(); it != order_.end();) {
  594. if (nodes.count(*it)) {
  595. (void)it++;
  596. } else {
  597. MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
  598. it = order_.erase(it);
  599. }
  600. }
  601. }
  602. }
  603. }
  604. void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
  605. if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
  606. order_.remove(n->cast<CNodePtr>());
  607. MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
  608. }
  609. }
  610. void FuncGraph::CheckOrder() {
  611. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  612. MS_LOG(DEBUG) << "Check graph " << ToString();
  613. for (auto it = order_.begin(); it != order_.end(); (void)it++) {
  614. for (const auto &input_node : (*it)->inputs()) {
  615. if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
  616. // Need to reorder the wrong order node.
  617. auto found = std::find(order_.begin(), it, input_node);
  618. if (found == it) {
  619. DumpCNodeList();
  620. MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString()
  621. << " doesn't obey the input dependency, "
  622. << "as input " << input_node->DebugString() << " is not ahead of itself.";
  623. }
  624. }
  625. }
  626. }
  627. auto mng = manager_.lock();
  628. if (mng != nullptr) {
  629. const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
  630. if (nodes.size() != (order_.size() + parameters_.size())) {
  631. DumpCNodeList();
  632. MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
  633. << nodes.size() - parameters_.size() << ".";
  634. }
  635. }
  636. MS_LOG(DEBUG) << "Check order okay.";
  637. }
  638. }
  639. const char kPrimHasEffect[] = "_side_effect_flag";
  640. bool FuncGraph::HasEffect(const CNodePtr &cnode) {
  641. auto prim = GetCNodePrimitive(cnode);
  642. if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
  643. auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
  644. auto prim_val = do_sig->function();
  645. if (prim_val != nullptr && prim_val->isa<Primitive>()) {
  646. prim = prim_val->cast<PrimitivePtr>();
  647. } else {
  648. prim = nullptr;
  649. }
  650. }
  651. if (prim != nullptr) {
  652. auto effect_val = prim->GetAttr(kPrimHasEffect);
  653. if (effect_val && effect_val->isa<BoolImm>()) {
  654. auto effect_bool = GetValue<bool>(effect_val);
  655. return effect_bool;
  656. }
  657. }
  658. return false;
  659. }
  660. std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
  661. std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
  662. for (const auto &node : segment) {
  663. if (roots->size() == 1) {
  664. return roots;
  665. }
  666. auto input_size = node->size();
  667. for (size_t i = 0; i < input_size; i++) {
  668. auto in_node = node->input(i);
  669. auto in_cnode = in_node->cast<CNodePtr>();
  670. if (in_cnode != nullptr) {
  671. (void)roots->erase(in_cnode);
  672. }
  673. }
  674. }
  675. return roots;
  676. }
  677. std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
  678. std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
  679. for (const auto &node : segment) {
  680. if (nodes->size() == 1) {
  681. return nodes;
  682. }
  683. if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
  684. (void)nodes->erase(node);
  685. continue;
  686. }
  687. auto input_size = node->size();
  688. for (size_t i = 0; i < input_size; i++) {
  689. auto in_node = node->input(i);
  690. if (!in_node->isa<CNode>()) {
  691. continue;
  692. }
  693. auto in_cnode = in_node->cast<CNodePtr>();
  694. if (in_cnode != nullptr) {
  695. if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
  696. (void)nodes->erase(node);
  697. break;
  698. }
  699. }
  700. }
  701. }
  702. return nodes;
  703. }
  704. void FuncGraph::ReleaseFullOrderToEffectOrder() {
  705. MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << ".";
  706. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  707. std::list<AnfNodePtr> depends_order;
  708. std::vector<CNodePtr> segment;
  709. for (const auto &cnode : order_) {
  710. if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
  711. continue;
  712. }
  713. if (HasEffect(cnode)) {
  714. MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << ".";
  715. if (segment.size() > 0) {
  716. auto roots = FindRoots(segment);
  717. for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) {
  718. depends_order.push_back(*iter);
  719. }
  720. }
  721. segment.clear();
  722. depends_order.push_back(cnode);
  723. } else {
  724. MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << ".";
  725. segment.push_back(cnode);
  726. }
  727. }
  728. if (segment.size() > 1) {
  729. auto roots = FindRoots(segment);
  730. for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) {
  731. depends_order.push_back(*iter);
  732. }
  733. }
  734. std::vector<AnfNodePtr> depend_inputs;
  735. auto old_ret = output();
  736. for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) {
  737. if (*iter != old_ret) {
  738. depend_inputs.push_back(*iter);
  739. }
  740. }
  741. set_flags(GRAPH_FLAG_HAS_EFFECT, false);
  742. set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
  743. if (!depend_inputs.empty()) {
  744. SetEffectDepends(depend_inputs);
  745. }
  746. }
  747. }
  748. void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
  749. auto old_ret = output();
  750. std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
  751. (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());
  752. auto new_ret = NewCNode(inputs);
  753. auto mng = manager();
  754. if (mng) {
  755. (void)mng->Replace(old_ret, new_ret);
  756. } else {
  757. return_->set_input(1, new_ret);
  758. }
  759. }
  760. const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
  761. const char kFuncGraphFlagUndetermined[] = "Undeterminate";
  762. } // namespace mindspore