You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/func_graph.h"
  19. #include <algorithm>
  20. #include <sstream>
  21. #include <utility>
  22. #include "debug/trace.h"
  23. #include "ir/manager.h"
  24. #include "operator/ops.h"
  25. #include "pybind_api/export_flags.h"
  26. #include "utils/ordered_set.h"
  27. #include "utils/convert_utils_base.h"
  28. namespace mindspore {
  29. /*
  30. * Methods of Graph
  31. */
  32. FuncGraph::FuncGraph()
  33. : attrs_(),
  34. transforms_(),
  35. parameter_default_value_(),
  36. seen_(0),
  37. parameters_(),
  38. has_vararg_(false),
  39. has_kwarg_(false),
  40. kwonlyargs_count_(0),
  41. hyper_param_count_(0),
  42. is_generated_(false),
  43. return_(nullptr),
  44. manager_(std::weak_ptr<FuncGraphManager>()) {
  45. debug_info_ = std::make_shared<GraphDebugInfo>();
  46. }
  47. AnfNodePtr FuncGraph::output() const {
  48. // If return value is set, return should have two inputs.
  49. if (return_ != nullptr && return_->inputs().size() == 2) {
  50. return return_->input(1);
  51. } else {
  52. // If not set yet, return nullptr.
  53. return nullptr;
  54. }
  55. }
  56. ParameterPtr FuncGraph::add_parameter() {
  57. FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
  58. ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
  59. add_parameter(p);
  60. return p;
  61. }
  62. void FuncGraph::add_parameter(const ParameterPtr &p) {
  63. if (manager_.lock()) {
  64. std::vector<AnfNodePtr> new_params = parameters_;
  65. new_params.push_back(p);
  66. manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
  67. } else {
  68. parameters_.push_back(p);
  69. }
  70. }
  71. ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
  72. FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  73. ParameterPtr p = std::make_shared<Parameter>(this_graph);
  74. p->set_name(name);
  75. p->debug_info()->set_name(name);
  76. std::vector<AnfNodePtr> new_params = parameters_;
  77. // append parameter
  78. new_params.push_back(p);
  79. if (manager_.lock()) {
  80. manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
  81. } else {
  82. parameters_.push_back(p);
  83. }
  84. hyper_param_count_++;
  85. return p;
  86. }
  87. bool FuncGraph::has_flag(const std::string &key) {
  88. auto iter = attrs_.find(key);
  89. if (iter != attrs_.cend()) {
  90. if (iter->second->isa<BoolImm>()) {
  91. return GetValue<bool>(iter->second);
  92. }
  93. MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
  94. }
  95. return false;
  96. }
  97. bool FuncGraph::has_attr(const std::string &key) {
  98. auto iter = attrs_.find(key);
  99. return !(iter == attrs_.cend());
  100. }
  101. ValuePtr FuncGraph::get_attr(const std::string &key) {
  102. auto iter = attrs_.find(key);
  103. return iter == attrs_.cend() ? nullptr : iter->second;
  104. }
  105. CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  106. CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
  107. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  108. order_.push_back(cnode);
  109. MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order.";
  110. }
  111. return cnode;
  112. }
  113. CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
  114. CNodePtr app = NewCNode(inputs);
  115. app->set_scope(scope);
  116. return app;
  117. }
  118. void FuncGraph::DumpCNodeList() {
  119. MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
  120. for (const auto &cnode : order_) {
  121. MS_LOG(INFO) << cnode->DebugString();
  122. }
  123. }
  124. std::string FuncGraph::ToString() const {
  125. return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
  126. }
  127. GraphDebugInfoPtr FuncGraph::debug_info() {
  128. MS_EXCEPTION_IF_NULL(this->debug_info_);
  129. if (this->debug_info_->get_graph() == nullptr) {
  130. this->debug_info_->set_graph(shared_from_base<FuncGraph>());
  131. }
  132. return this->debug_info_;
  133. }
  134. const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
  135. void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }
  136. void FuncGraph::ClearNodes() { nodes_.clear(); }
  137. void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); }
  138. void FuncGraph::DropNode(AnfNodePtr node) {
  139. nodes_.erase(node);
  140. auto graph = node->func_graph();
  141. // Remove the node from order list.
  142. if (graph) {
  143. graph->EraseUnusedNodeInOrder(node);
  144. }
  145. }
  146. const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
  147. void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
  148. auto &others = source->value_nodes();
  149. for (auto it = others.begin(); it != others.end(); it++) {
  150. AddValueNode(it->first, it->second);
  151. }
  152. }
  153. void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
  154. void FuncGraph::AddValueNode(AnfNodePtr node, int count) {
  155. if (value_nodes_.count(node) == 0) {
  156. value_nodes_[node] = count;
  157. } else {
  158. value_nodes_[node] += count;
  159. }
  160. }
  161. void FuncGraph::DropValueNode(AnfNodePtr node) {
  162. if (value_nodes_.count(node) != 0) {
  163. if (value_nodes_[node] == 1) {
  164. (void)value_nodes_.erase(node);
  165. } else {
  166. value_nodes_[node]--;
  167. if (value_nodes_[node] < 0) {
  168. MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
  169. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
  170. }
  171. }
  172. }
  173. }
  174. const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
  175. void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
  176. auto &others = source->free_variables();
  177. for (auto it = others.begin(); it != others.end(); it++) {
  178. if (it->first->func_graph().get() != this) {
  179. (void)AddFreeVariable(it->first, it->second);
  180. }
  181. }
  182. }
  183. void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
  184. bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) {
  185. if (free_variables_.count(node) == 0) {
  186. free_variables_[node] = count;
  187. return true;
  188. } else {
  189. free_variables_[node] += count;
  190. return false;
  191. }
  192. }
  193. bool FuncGraph::DropFreeVariable(AnfNodePtr node) {
  194. if (free_variables_.count(node) != 0) {
  195. if (free_variables_[node] == 1) {
  196. (void)free_variables_.erase(node);
  197. return true;
  198. } else {
  199. free_variables_[node]--;
  200. if (free_variables_[node] < 0) {
  201. MS_LOG(EXCEPTION) << "Count of free variable '" << node
  202. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
  203. }
  204. }
  205. }
  206. return false;
  207. }
  208. const BaseRefCounterMap &FuncGraph::free_variables_total() {
  209. auto mng = manager_.lock();
  210. MS_EXCEPTION_IF_NULL(mng);
  211. auto &fv_total = mng->free_variables_total();
  212. return fv_total[shared_from_base<FuncGraph>()];
  213. }
  214. std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
  215. std::vector<AnfNodePtr> nodes;
  216. const auto &fv_total = this->free_variables_total();
  217. for (auto &p : fv_total) {
  218. auto key = p.first;
  219. if (utils::isa<AnfNodePtr>(key)) {
  220. nodes.push_back(utils::cast<AnfNodePtr>(key));
  221. }
  222. }
  223. return nodes;
  224. }
  225. std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
  226. std::vector<FuncGraphPtr> func_graphs;
  227. const auto &fv_total = this->free_variables_total();
  228. for (auto &p : fv_total) {
  229. auto key = p.first;
  230. if (utils::isa<FuncGraphPtr>(key)) {
  231. func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
  232. }
  233. }
  234. return func_graphs;
  235. }
  236. const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }
  237. void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
  238. auto &others = source->func_graphs_used();
  239. for (auto it = others.begin(); it != others.end(); it++) {
  240. (void)AddFuncGraphUsed(it->first, it->second);
  241. }
  242. func_graphs_used_.erase(source);
  243. }
  244. void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
  245. bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
  246. if (func_graphs_used_.count(fg) == 0) {
  247. func_graphs_used_[fg] = count;
  248. return true;
  249. } else {
  250. func_graphs_used_[fg] += count;
  251. return false;
  252. }
  253. }
  254. bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
  255. if (func_graphs_used_.count(fg) != 0) {
  256. if (func_graphs_used_[fg] == 1) {
  257. (void)func_graphs_used_.erase(fg);
  258. return true;
  259. } else {
  260. func_graphs_used_[fg]--;
  261. if (func_graphs_used_[fg] < 0) {
  262. MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
  263. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
  264. }
  265. }
  266. }
  267. return false;
  268. }
  269. const FuncGraphSet &FuncGraph::func_graphs_used_total() {
  270. auto mng = manager_.lock();
  271. MS_EXCEPTION_IF_NULL(mng);
  272. auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
  273. return used;
  274. }
  275. const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
  276. void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
  277. auto &others = source->func_graph_cnodes_index();
  278. for (auto it = others.begin(); it != others.end(); it++) {
  279. // Ignore the user graph who may own itself.
  280. auto fg = it->first->first->func_graph();
  281. MS_EXCEPTION_IF_NULL(fg);
  282. if (fg.get() != this) {
  283. AddFuncGraphCNodeIndex(it->first, it->second);
  284. }
  285. }
  286. }
  287. void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
  288. void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) {
  289. if (func_graph_cnodes_index_.count(pair) == 0) {
  290. func_graph_cnodes_index_[pair] = count;
  291. } else {
  292. func_graph_cnodes_index_[pair] += count;
  293. }
  294. }
  295. void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
  296. if (func_graph_cnodes_index_.count(pair) != 0) {
  297. if (func_graph_cnodes_index_[pair] == 1) {
  298. (void)func_graph_cnodes_index_.erase(pair);
  299. } else {
  300. func_graph_cnodes_index_[pair]--;
  301. if (func_graph_cnodes_index_[pair] < 0) {
  302. MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
  303. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
  304. }
  305. }
  306. }
  307. }
  308. const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
  309. void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
  310. auto &others = source->j_func_graphs();
  311. for (auto it = others.begin(); it != others.end(); it++) {
  312. AddJFuncGraph(it->first, it->second);
  313. }
  314. }
  315. void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
  316. void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
  317. if (j_func_graphs_.count(fg) == 0) {
  318. j_func_graphs_[fg] = count;
  319. } else {
  320. j_func_graphs_[fg] += count;
  321. }
  322. }
  323. void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
  324. if (j_func_graphs_.count(fg) != 0) {
  325. if (j_func_graphs_[fg] == 1) {
  326. (void)j_func_graphs_.erase(fg);
  327. } else {
  328. j_func_graphs_[fg]--;
  329. if (j_func_graphs_[fg] < 0) {
  330. MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
  331. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
  332. }
  333. }
  334. }
  335. }
  336. FuncGraphPtr FuncGraph::parent() {
  337. // report the bug early.
  338. if (manager_.lock() == nullptr) {
  339. MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
  340. << " NodeInfo: " << trace::GetDebugInfo(debug_info());
  341. }
  342. auto mng = manager_.lock();
  343. MS_EXCEPTION_IF_NULL(mng);
  344. return mng->parent(shared_from_base<FuncGraph>());
  345. }
  346. const FuncGraphSet &FuncGraph::children() {
  347. auto mng = manager_.lock();
  348. MS_EXCEPTION_IF_NULL(mng);
  349. return mng->children(shared_from_base<FuncGraph>());
  350. }
  351. const FuncGraphSet &FuncGraph::scope() {
  352. auto mng = manager_.lock();
  353. MS_EXCEPTION_IF_NULL(mng);
  354. return mng->scopes(shared_from_base<FuncGraph>());
  355. }
  356. bool FuncGraph::recursive() {
  357. auto mng = manager_.lock();
  358. MS_EXCEPTION_IF_NULL(mng);
  359. return mng->recursive(shared_from_base<FuncGraph>());
  360. }
  361. std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
  362. auto mng = manager_.lock();
  363. MS_EXCEPTION_IF_NULL(mng);
  364. return mng->recursive_graphs(shared_from_base<FuncGraph>());
  365. }
  366. AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
  367. auto itr = this->parameter_default_value_.find(name);
  368. if (itr == parameter_default_value_.end()) {
  369. return nullptr;
  370. }
  371. auto default_value = itr->second;
  372. if (default_value == nullptr) {
  373. MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
  374. }
  375. if (IsValueNode<Null>(default_value)) {
  376. return nullptr;
  377. }
  378. return default_value;
  379. }
  380. // set the default values
  381. void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
  382. auto all_is_null =
  383. std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
  384. if (value_list.empty()) {
  385. all_is_null = true;
  386. }
  387. for (size_t i = 0; i < name_list.size(); ++i) {
  388. if (!all_is_null) {
  389. this->parameter_default_value_[name_list[i]] = value_list[i];
  390. }
  391. }
  392. }
  393. void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
  394. size_t FuncGraph::GetDefaultValueCount() {
  395. int null_count =
  396. std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
  397. [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
  398. return parameter_default_value_.size() - IntToSize(null_count);
  399. }
  400. AnfNodePtr FuncGraph::GetVariableArgParameter() {
  401. if (!has_vararg_) {
  402. return nullptr;
  403. }
  404. if (has_kwarg_) {
  405. if (parameters_.size() < hyper_param_count_ + 2) {
  406. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  407. << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
  408. }
  409. return parameters_[parameters_.size() - hyper_param_count_ - 2];
  410. }
  411. if (parameters_.size() < hyper_param_count_ + 1) {
  412. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  413. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  414. }
  415. return parameters_[parameters_.size() - hyper_param_count_ - 1];
  416. }
  417. std::string FuncGraph::GetVariableArgName() {
  418. if (!has_vararg_) {
  419. return "";
  420. }
  421. if (has_kwarg_) {
  422. if (parameters_.size() < hyper_param_count_ + 2) {
  423. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  424. << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
  425. }
  426. return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast<ParameterPtr>()->name();
  427. }
  428. if (parameters_.size() < hyper_param_count_ + 1) {
  429. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  430. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  431. }
  432. return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>()->name();
  433. }
  434. AnfNodePtr FuncGraph::GetVariableKwargParameter() {
  435. if (has_kwarg_) {
  436. if (parameters_.size() < hyper_param_count_ + 1) {
  437. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  438. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  439. }
  440. return parameters_[parameters_.size() - hyper_param_count_ - 1];
  441. }
  442. return nullptr;
  443. }
  444. std::string FuncGraph::GetVariableKwargName() {
  445. if (has_kwarg_) {
  446. if (parameters_.size() < hyper_param_count_ + 1) {
  447. MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
  448. << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
  449. }
  450. return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>()->name();
  451. }
  452. return "";
  453. }
  454. int FuncGraph::GetPositionalArgsCount() const {
  455. int count = SizeToInt(parameters_.size());
  456. if (has_kwarg_) {
  457. count--;
  458. }
  459. if (has_vararg_) {
  460. count--;
  461. }
  462. return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
  463. }
  464. AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
  465. for (size_t i = 0; i < parameters_.size(); ++i) {
  466. MS_EXCEPTION_IF_NULL(parameters_[i]);
  467. auto param_cast = parameters_[i]->cast<ParameterPtr>();
  468. MS_EXCEPTION_IF_NULL(param_cast);
  469. if (param_cast->name() == name) {
  470. return parameters_[i];
  471. }
  472. }
  473. return nullptr;
  474. }
  475. void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
  476. std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
  477. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  478. MS_LOG(DEBUG) << "Return ordered cnodes.";
  479. return order_;
  480. } else {
  481. auto this_ptr = shared_from_base<FuncGraph>();
  482. auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
  483. auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
  484. std::list<CNodePtr> cnodes;
  485. auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
  486. for (const auto &node : nodes) {
  487. auto cnode = dyn_cast<CNode>(node);
  488. if (cnode) {
  489. cnodes.push_back(cnode);
  490. }
  491. }
  492. return cnodes;
  493. }
  494. }
  495. void FuncGraph::EraseUnusedNodeInOrder() {
  496. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  497. auto mng = manager_.lock();
  498. if (mng) {
  499. auto &all_nodes = nodes();
  500. // Erase unused cnode.
  501. for (auto it = order_.begin(); it != order_.end();) {
  502. if (all_nodes.count(*it)) {
  503. (void)it++;
  504. } else {
  505. MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
  506. it = order_.erase(it);
  507. }
  508. }
  509. }
  510. }
  511. }
  512. void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
  513. if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
  514. order_.remove(n->cast<CNodePtr>());
  515. MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
  516. }
  517. }
  518. void FuncGraph::CheckOrder() {
  519. if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
  520. MS_LOG(DEBUG) << "Check graph " << ToString();
  521. for (auto it = order_.begin(); it != order_.end(); (void)it++) {
  522. for (const auto &input_node : (*it)->inputs()) {
  523. if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
  524. // Need to reorder the wrong order node.
  525. auto found = std::find(order_.begin(), it, input_node);
  526. if (found == it) {
  527. DumpCNodeList();
  528. MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString()
  529. << " doesn't obey the input dependency, "
  530. << "as input " << input_node->DebugString() << " is not ahead of itself.";
  531. }
  532. }
  533. }
  534. }
  535. auto mng = manager_.lock();
  536. if (mng != nullptr) {
  537. const auto &all_nodes = nodes();
  538. if (all_nodes.size() != (order_.size() + parameters_.size())) {
  539. DumpCNodeList();
  540. MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
  541. << all_nodes.size() - parameters_.size() << ".";
  542. }
  543. }
  544. MS_LOG(DEBUG) << "Check order okay.";
  545. }
  546. }
  547. size_t NewFgSeenGeneration() {
  548. static size_t fg_seen_generation = 0;
  549. return ++fg_seen_generation;
  550. }
  551. const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
  552. const char kFuncGraphFlagUndetermined[] = "Undeterminate";
  553. } // namespace mindspore