You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

manager.cc 28 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/manager.h"
  19. #include <algorithm>
  20. #include <numeric>
  21. #include <list>
  22. #include "debug/trace_base.h"
  23. #include "ir/func_graph.h"
  24. #include "utils/profile.h"
  25. #include "utils/convert_utils_base.h"
  26. #include "operator/ops.h"
  27. namespace mindspore {
  28. FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
  29. auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
  30. m->Init();
  31. return m;
  32. }
  33. FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
  34. FuncGraphManagerPtr m = nullptr;
  35. bool root = false;
  36. for (auto &fg : func_graphs) {
  37. if (fg == nullptr) {
  38. continue;
  39. }
  40. if (fg->manager() != nullptr) {
  41. m = fg->manager();
  42. break;
  43. }
  44. }
  45. if (m == nullptr) {
  46. std::vector<FuncGraphPtr> tmp;
  47. m = MakeManager(tmp, manage);
  48. root = true;
  49. }
  50. for (auto &fg : func_graphs) {
  51. if (fg == nullptr) {
  52. continue;
  53. }
  54. m->AddFuncGraph(fg, root);
  55. }
  56. return m;
  57. }
  58. FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
  59. std::vector<FuncGraphPtr> func_graphs = {func_graph};
  60. return Manage(func_graphs, manage);
  61. }
  62. FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
  63. : roots_(roots), is_manage_(manage) {
  64. Reset();
  65. }
  66. void FuncGraphManager::Reset() {
  67. func_graphs_ = FuncGraphSet();
  68. all_nodes_ = AnfNodeSet();
  69. node_users_ = NodeUsersMap();
  70. signals_ = std::make_shared<Signals>();
  71. func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
  72. func_graph_parent_ = std::make_shared<ParentComputer>(this);
  73. children_ = std::make_shared<ChildrenComputer>(this);
  74. scopes_ = std::make_shared<ScopeComputer>(this);
  75. free_variables_total_ = std::make_shared<FVTotalComputer>(this);
  76. func_graphs_used_total_ = std::make_shared<FuncGraphsUsedTotalComputer>(this);
  77. recursive_ = std::make_shared<RecursiveComputer>(this);
  78. j_total_ = std::make_shared<FuncGraphJTotalComputer>(this);
  79. limit_ = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
  80. }
  81. void FuncGraphManager::Init() {
  82. auto roots = roots_;
  83. roots_ = FuncGraphSet();
  84. for (auto &fg : roots) {
  85. AddFuncGraph(fg, true);
  86. }
  87. }
  88. FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
  89. MS_EXCEPTION_IF_NULL(fg);
  90. MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
  91. func_graph_parents_total_->Recompute(fg);
  92. MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
  93. return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
  94. }
  95. FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
  96. MS_EXCEPTION_IF_NULL(fg);
  97. MS_EXCEPTION_IF_NULL(func_graph_parent_);
  98. MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
  99. func_graph_parent_->Recompute(fg);
  100. if (func_graph_parent_->parent_analysis().count(fg) == 0) {
  101. MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString();
  102. return nullptr;
  103. }
  104. MS_LOG(DEBUG) << "End parents func graph " << fg->ToString();
  105. return func_graph_parent_->parent_analysis()[fg];
  106. }
  107. FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
  108. MS_EXCEPTION_IF_NULL(fg);
  109. MS_EXCEPTION_IF_NULL(children_);
  110. MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
  111. children_->Recompute(fg);
  112. return children_->children_analysis()[fg];
  113. }
  114. FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
  115. MS_EXCEPTION_IF_NULL(fg);
  116. MS_EXCEPTION_IF_NULL(scopes_);
  117. MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
  118. scopes_->Recompute(fg);
  119. MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString();
  120. return scopes_->scope_analysis()[fg];
  121. }
  122. FVTotalMap &FuncGraphManager::free_variables_total() const {
  123. MS_EXCEPTION_IF_NULL(free_variables_total_);
  124. free_variables_total_->Recompute();
  125. return free_variables_total_->fv_total_analysis();
  126. }
  127. FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
  128. MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
  129. func_graphs_used_total_->Recompute(fg);
  130. return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
  131. }
  132. bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
  133. MS_EXCEPTION_IF_NULL(fg);
  134. recursive_->Recompute(fg);
  135. if (recursive_->recursive_analysis().count(fg) == 0) {
  136. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  137. return false;
  138. }
  139. return recursive_->recursive_analysis()[fg];
  140. }
  141. std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
  142. MS_EXCEPTION_IF_NULL(fg);
  143. if (recursive(fg)) {
  144. if (!recursive_->recursive_map().count(fg)) {
  145. auto trace = std::list<FuncGraphPtr>();
  146. recursive_->CheckRecursiveGraphs(fg, &trace);
  147. }
  148. if (recursive_->recursive_map().count(fg) == 0) {
  149. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  150. return nullptr;
  151. }
  152. return recursive_->recursive_map()[fg];
  153. } else {
  154. return nullptr;
  155. }
  156. }
  157. bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
  158. MS_EXCEPTION_IF_NULL(j_total_);
  159. MS_EXCEPTION_IF_NULL(fg);
  160. j_total_->Recompute(fg);
  161. if (j_total_->j_total_analysis().count(fg) == 0) {
  162. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  163. return false;
  164. }
  165. return j_total_->j_total_analysis()[fg];
  166. }
  167. // add a func graph to this manager, optionally as a root func graph.
  168. void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
  169. MS_EXCEPTION_IF_NULL(func_graph);
  170. if (is_root) {
  171. roots_.add(func_graph);
  172. }
  173. if (func_graphs_.contains(func_graph)) {
  174. return;
  175. }
  176. AddIntoManaged(func_graph);
  177. std::vector<AnfNodePtr> para = func_graph->parameters();
  178. AcquireNodes(para);
  179. std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
  180. AcquireNodes(return_vec);
  181. }
  182. // clear the all information in manager
  183. void FuncGraphManager::Clear() {
  184. func_graphs_.clear();
  185. all_nodes_.clear();
  186. node_users_.clear();
  187. roots_.clear();
  188. signals_->InvalidateComputer();
  189. }
  190. void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
  191. MS_LOG(DEBUG) << "Start keep roots";
  192. bool root_exist = false;
  193. for (auto &item : func_graphs) {
  194. if (roots_.contains(item)) {
  195. root_exist = true;
  196. break;
  197. }
  198. }
  199. // if the new_root in roots_, we add new_root first, then calculate the func_graphs
  200. // relation to new_root, remove the func_graphs not relation to new_root
  201. // if the new_root not in roots_, we clear the all func_graphs in manager
  202. // then add the new_root
  203. if (root_exist || func_graphs.empty()) {
  204. FuncGraphSet roots(func_graphs);
  205. if (roots.empty()) {
  206. roots = roots_;
  207. } else {
  208. roots_.clear();
  209. for (auto &item : roots) {
  210. AddFuncGraph(item, true);
  211. }
  212. }
  213. FuncGraphSet keep;
  214. for (auto &item : roots) {
  215. MS_LOG(DEBUG) << "roots: " << item->ToString();
  216. keep.update(func_graphs_used_total(item));
  217. #ifdef DEBUG
  218. for (auto &k : keep) {
  219. MS_LOG(DEBUG) << "keep: " << k->ToString();
  220. }
  221. #endif
  222. }
  223. MaybeDropFuncGraphs(func_graphs_ - keep, true);
  224. } else {
  225. Clear();
  226. FuncGraphSet roots(func_graphs);
  227. for (auto &item : roots) {
  228. AddFuncGraph(item, true);
  229. }
  230. }
  231. }
  232. void FuncGraphManager::RemoveRoots() {
  233. MS_LOG(DEBUG) << "Start remove roots";
  234. roots_.clear();
  235. MaybeDropFuncGraphs(func_graphs_, true);
  236. }
  237. void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
  238. MS_EXCEPTION_IF_NULL(fg);
  239. if (is_manage_) {
  240. if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
  241. MS_LOG(WARNING) << "A func graph can only have one manager.";
  242. }
  243. FuncGraphManagerPtr this_manager = shared_from_this();
  244. fg->set_manager(this_manager);
  245. }
  246. func_graphs_.add(fg);
  247. }
  248. void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
  249. FuncGraphSet todo(func_graphs);
  250. std::set<FuncGraphPtr> dropped;
  251. // int count = 0;
  252. while (!todo.empty()) {
  253. FuncGraphPtr func_graph = todo.pop();
  254. MS_EXCEPTION_IF_NULL(func_graph);
  255. MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString();
  256. if (roots_.contains(func_graph)) {
  257. MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
  258. continue;
  259. }
  260. auto &users_cnode_index = func_graph->func_graph_cnodes_index();
  261. if (!users_cnode_index.empty() && !ignore_users) {
  262. MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
  263. continue;
  264. }
  265. if (dropped.find(func_graph) != dropped.end()) {
  266. MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString();
  267. continue;
  268. }
  269. (void)dropped.insert(func_graph);
  270. std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
  271. todo.update(MaybeDropNodes(return_vec));
  272. }
  273. for (auto &fg : dropped) {
  274. MS_EXCEPTION_IF_NULL(fg);
  275. all_nodes_.difference_update(fg->parameters());
  276. (void)func_graphs_.erase(fg);
  277. if (fg->manager().get() == this) {
  278. fg->set_manager(nullptr);
  279. }
  280. MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString();
  281. }
  282. }
  283. void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) {
  284. MS_EXCEPTION_IF_NULL(inp);
  285. if (direction == kDecEdge) {
  286. MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
  287. auto &users_node = node_users_[inp];
  288. if (!users_node.contains(make_pair(node, index))) {
  289. return;
  290. }
  291. (void)users_node.erase(make_pair(node, index));
  292. DropEdge(node, index, inp);
  293. } else {
  294. MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
  295. if (IsValueNode<FuncGraph>(inp)) {
  296. MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
  297. AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
  298. }
  299. auto &users_node = node_users_[inp];
  300. users_node.add(make_pair(node, index));
  301. AddEdge(node, index, inp);
  302. }
  303. }
  304. void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) {
  305. MS_EXCEPTION_IF_NULL(node);
  306. if (node->isa<CNode>()) {
  307. auto cnode = node->cast<CNodePtr>();
  308. int index = 0;
  309. for (auto &inp : cnode->inputs()) {
  310. ProcessEdge(cnode, index, inp, direction);
  311. ++index;
  312. }
  313. }
  314. }
  315. IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) {
  316. if (all_nodes_.contains(node)) {
  317. return EXCLUDE;
  318. } else {
  319. return FOLLOW;
  320. }
  321. }
  322. void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
  323. AnfNodeSet acq;
  324. for (auto &node : nodes) {
  325. AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit_));
  326. all_nodes_.update(new_nodes);
  327. acq.update(new_nodes);
  328. }
  329. for (auto &node : acq) {
  330. MS_EXCEPTION_IF_NULL(node);
  331. auto fg = node->func_graph();
  332. if (fg != nullptr) {
  333. fg->AddNode(node);
  334. }
  335. ProcessInputs(node, kIncEdge);
  336. }
  337. }
  338. FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
  339. AnfNodeSet nodes_ordered(nodes);
  340. FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
  341. while (!nodes_ordered.empty()) {
  342. AnfNodePtr node = nodes_ordered.pop();
  343. MS_EXCEPTION_IF_NULL(node);
  344. if (!all_nodes_.contains(node)) {
  345. continue;
  346. }
  347. AnfNodeIndexSet &users = node_users_[node];
  348. std::vector<AnfNodePtr> parameters;
  349. if (!users.empty() ||
  350. (node->isa<Parameter>() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) {
  351. continue;
  352. }
  353. if (IsValueNode<FuncGraph>(node)) {
  354. auto fg = GetValueNode<FuncGraphPtr>(node);
  355. func_graphs_to_check->add(fg);
  356. MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString()
  357. << " to null";
  358. }
  359. ProcessInputs(node, kDecEdge);
  360. (void)all_nodes_.erase(node);
  361. if (node->func_graph() != nullptr) {
  362. node->func_graph()->DropNode(node);
  363. }
  364. if (node->isa<CNode>()) {
  365. auto cnode = node->cast<CNodePtr>();
  366. nodes_ordered.update(cnode->inputs());
  367. }
  368. (void)node_users_.erase(node);
  369. }
  370. return func_graphs_to_check;
  371. }
  372. void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters) {
  373. auto tr = Transact();
  374. tr.SetParameters(fg, parameters);
  375. tr.Commit();
  376. }
  377. bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  378. auto tr = Transact();
  379. bool success = tr.Replace(old_node, new_node);
  380. if (success) {
  381. tr.Commit();
  382. }
  383. return success;
  384. }
  385. void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
  386. auto tr = Transact();
  387. tr.SetEdge(node, index, value);
  388. tr.Commit();
  389. }
  390. void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) {
  391. AnfNodePtr source_return = source->get_return();
  392. AnfNodePtr source_output = source->output();
  393. AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
  394. int index = 0;
  395. (void)node_users_[source_prim].erase(make_pair(source_return, index));
  396. DropEdge(source_return, index, source_prim);
  397. index = 1;
  398. (void)node_users_[source_output].erase(make_pair(source_return, index));
  399. DropEdge(source_return, index, source_output);
  400. (void)all_nodes_.erase(source_return);
  401. (void)node_users_.erase(source_return);
  402. source->DropNode(source_return);
  403. for (auto &node : source->nodes()) {
  404. node->set_func_graph(target);
  405. if (node->scope() == kDefaultScope) {
  406. node->set_scope(scope);
  407. }
  408. }
  409. MoveAllNodes(source, target);
  410. all_nodes_.difference_update(source->parameters());
  411. (void)func_graphs_.erase(source);
  412. if (source->manager().get() == this) {
  413. source->set_manager(nullptr);
  414. }
  415. }
  416. void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) {
  417. auto fg = node->func_graph();
  418. if (input->isa<ValueNode>()) {
  419. fg->AddValueNode(input);
  420. if (IsValueNode<FuncGraph>(input)) {
  421. auto used = GetValueNode<FuncGraphPtr>(input);
  422. used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
  423. if (fg->AddFuncGraphUsed(used)) {
  424. signals_->InvalidateComputer();
  425. }
  426. if (IsPrimitiveCNode(node, prim::kPrimJ)) {
  427. fg->AddJFuncGraph(used);
  428. }
  429. }
  430. } else if (fg != nullptr && fg != input->func_graph()) {
  431. if (fg->AddFreeVariable(input)) {
  432. signals_->InvalidateComputer();
  433. }
  434. }
  435. }
  436. void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) {
  437. auto fg = node->func_graph();
  438. if (input->isa<ValueNode>()) {
  439. fg->DropValueNode(input);
  440. if (IsValueNode<FuncGraph>(input)) {
  441. auto used = GetValueNode<FuncGraphPtr>(input);
  442. used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
  443. if (fg->DropFuncGraphUsed(used)) {
  444. signals_->InvalidateComputer();
  445. }
  446. if (IsPrimitiveCNode(node, prim::kPrimJ)) {
  447. fg->DropJFuncGraph(used);
  448. }
  449. }
  450. } else if (fg != nullptr && fg != input->func_graph()) {
  451. if (fg->DropFreeVariable(input)) {
  452. signals_->InvalidateComputer();
  453. }
  454. }
  455. }
  456. void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
  457. target->CopyNodes(source);
  458. target->CopyValueNodes(source);
  459. target->CopyFuncGraphCNodesIndex(source);
  460. target->CopyFreeVariables(source);
  461. target->CopyFuncGraphsUsed(source);
  462. target->CopyJFuncGraphs(source);
  463. signals_->InvalidateComputer();
  464. source->ClearNodes();
  465. source->ClearValueNodes();
  466. source->ClearFuncGraphCNodesIndex();
  467. source->ClearFreeVariables();
  468. source->ClearFuncGraphsUsed();
  469. source->ClearJFuncGraphs();
  470. }
  471. FuncGraphTransaction FuncGraphManager::Transact() {
  472. auto tr = FuncGraphTransaction(this);
  473. return tr;
  474. }
  475. void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges,
  476. EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) {
  477. for (auto &iter : changes) {
  478. auto operation = iter.op;
  479. auto args = iter.args;
  480. if (operation == Change::kTxSetEdge) {
  481. auto edge = args.cast<ArgsOfSetEdge>();
  482. auto old_node = edge.root_node->input(edge.index);
  483. (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
  484. (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
  485. (*rms)[old_node] += 1;
  486. (*adds)[edge.new_node] += 1;
  487. edge.root_node->set_input(edge.index, edge.new_node);
  488. } else if (operation == Change::kTxSetParams) {
  489. auto param = args.cast<ArgsOfSetParams>();
  490. MS_EXCEPTION_IF_NULL(param.func_graph);
  491. auto old_parameters = param.func_graph->parameters();
  492. for (auto &p : param.params) {
  493. (*adds)[p] += 1;
  494. }
  495. for (auto &p : old_parameters) {
  496. (*rms)[p] += 1;
  497. }
  498. param.func_graph->set_parameters(param.params);
  499. }
  500. }
  501. }
  502. void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) {
  503. EdgeTupleCounter add_edges;
  504. EdgeTupleCounter rm_edges;
  505. Counter<AnfNodePtr> adds;
  506. Counter<AnfNodePtr> rms;
  507. ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
  508. auto sub_edges = add_edges - rm_edges;
  509. for (auto &iter : sub_edges) {
  510. auto root_node = iter.first.first;
  511. int index = iter.first.second.first;
  512. auto new_node = iter.first.second.second;
  513. ProcessEdge(root_node, index, new_node, kIncEdge);
  514. }
  515. auto sub_nodes = adds - rms;
  516. std::vector<AnfNodePtr> nodes;
  517. (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
  518. [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
  519. AcquireNodes(nodes);
  520. auto sub_edges_reverse = rm_edges - add_edges;
  521. for (auto &iter : sub_edges_reverse) {
  522. auto root_node = iter.first.first;
  523. int index = iter.first.second.first;
  524. auto old_node = iter.first.second.second;
  525. ProcessEdge(root_node, index, old_node, kDecEdge);
  526. }
  527. auto sub_nodes_reverse = rms - adds;
  528. std::vector<AnfNodePtr> nodes_reverse;
  529. (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
  530. [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
  531. auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
  532. MaybeDropFuncGraphs(*drop_func_graphs);
  533. }
  534. void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params) {
  535. changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
  536. }
  537. bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  538. MS_EXCEPTION_IF_NULL(old_node);
  539. MS_EXCEPTION_IF_NULL(new_node);
  540. FuncGraphPtr old_func_graph = old_node->func_graph();
  541. if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) {
  542. MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
  543. return false;
  544. }
  545. auto users = manager_->node_users()[old_node];
  546. for (auto &node : users) {
  547. SetEdge(node.first, node.second, new_node);
  548. }
  549. return true;
  550. }
  551. void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
  552. if (k < 0) {
  553. MS_LOG(EXCEPTION) << "Invalid value k = " << k;
  554. }
  555. MS_EXCEPTION_IF_NULL(src_node);
  556. auto cnode = src_node->cast<CNodePtr>();
  557. if (cnode == nullptr) {
  558. MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed.";
  559. }
  560. changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)});
  561. }
  562. void FuncGraphTransaction::Commit() {
  563. std::vector<Change> changes;
  564. changes_.swap(changes);
  565. manager_->CommitChanges(changes);
  566. }
  567. DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
  568. MS_EXCEPTION_IF_NULL(manager_);
  569. manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
  570. validate_ = false;
  571. }
  572. void DepComputer::Recompute() {
  573. if (!validate_) {
  574. RealRecompute();
  575. validate_ = true;
  576. }
  577. }
  578. void DepComputer::Recompute(const FuncGraphPtr &fg) {
  579. if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
  580. RealRecompute(fg);
  581. func_graphs_validate_[fg] = true;
  582. }
  583. }
  584. FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) {
  585. if (fg->seen_ == seen_num) {
  586. return std::make_shared<FuncGraphSet>();
  587. }
  588. FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
  589. // Append all the fvs in fg.
  590. auto &fvs = fg->free_variables();
  591. for (auto fv : fvs) {
  592. parents->add(fv.first->func_graph());
  593. }
  594. // Search the fv in fg's child func graph.
  595. auto &fgs = fg->func_graphs_used();
  596. for (auto &item : fgs) {
  597. fg->seen_ = seen_num;
  598. auto gt = item.first;
  599. parents->update(SeekParents(gt, seen_num));
  600. }
  601. (void)parents->erase(fg);
  602. return parents;
  603. }
  604. void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
  605. MS_EXCEPTION_IF_NULL(fg);
  606. func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration()));
  607. }
  608. bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
  609. auto l1 = lhs.second.size();
  610. auto l2 = rhs.second.size();
  611. return l1 < l2;
  612. }
  613. void ParentComputer::RealRecompute(FuncGraphPtr fg) {
  614. this->parent_analysis_[fg] = nullptr;
  615. // Note: must be a copy other than reference as it is modified thereafter.
  616. auto deps = this->manager_->func_graph_parents_total(fg);
  617. if (deps.empty()) {
  618. this->parent_analysis_[fg] = nullptr;
  619. return;
  620. } else if (deps.size() == 1) {
  621. this->parent_analysis_[fg] = deps.pop();
  622. return;
  623. } else {
  624. // return nearest parent as parent
  625. FuncGraphSet deps_copy(deps);
  626. for (auto &dep : deps) {
  627. auto parent_deps = this->manager_->func_graph_parents_total(dep);
  628. for (auto &p_d : parent_deps) {
  629. if (deps_copy.count(p_d)) {
  630. (void)deps_copy.erase(p_d);
  631. }
  632. }
  633. if (deps_copy.size() == 1) {
  634. this->parent_analysis_[fg] = deps_copy.pop();
  635. return;
  636. }
  637. }
  638. }
  639. }
  640. void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
  641. MS_EXCEPTION_IF_NULL(manager_);
  642. auto used_fg_total = manager_->func_graphs_used_total(fg);
  643. for (auto &used_fg : used_fg_total) {
  644. if (manager_->parent(used_fg) == fg) {
  645. children_analysis_[fg].add(used_fg);
  646. }
  647. }
  648. }
  649. void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
  650. MS_EXCEPTION_IF_NULL(manager_);
  651. auto &children = manager_->children(fg);
  652. scope_analysis_[fg] = FuncGraphSet();
  653. scope_analysis_[fg].add(fg);
  654. for (auto &child : children) {
  655. scope_analysis_[fg].add(child);
  656. }
  657. }
  658. void FVTotalComputer::RealRecompute() {
  659. auto manager = DepComputer::manager_;
  660. MS_EXCEPTION_IF_NULL(manager);
  661. for (auto &fg : manager->func_graphs()) {
  662. fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
  663. }
  664. for (auto &fg : manager->func_graphs()) {
  665. // add all free variable nodes
  666. AnfNodeCounterMap items = fg->free_variables();
  667. for (auto &iter : items) {
  668. auto curr = fg;
  669. while (curr != nullptr) {
  670. fv_total_analysis_[curr][iter.first] = iter.second;
  671. curr = manager->parent(curr);
  672. if (curr != nullptr) {
  673. const AnfNodeSet &all_nodes = curr->nodes();
  674. if (all_nodes.contains(iter.first)) {
  675. break;
  676. }
  677. }
  678. }
  679. }
  680. // add all FGs of free variables
  681. auto &used = fg->func_graphs_used();
  682. for (auto &iter : used) {
  683. auto p = manager->parent(iter.first);
  684. if (p == nullptr) {
  685. continue;
  686. }
  687. auto curr = fg;
  688. while (curr != p) {
  689. fv_total_analysis_[curr][iter.first] = iter.second;
  690. curr = manager->parent(curr);
  691. }
  692. }
  693. }
  694. }
  695. void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
  696. MS_EXCEPTION_IF_NULL(manager_);
  697. std::vector<FuncGraphPtr> todo;
  698. std::vector<FuncGraphPtr> todo_new;
  699. todo.push_back(fg);
  700. while (!todo.empty()) {
  701. todo_new.clear();
  702. for (auto &gt : todo) {
  703. for (auto &item : gt->func_graphs_used()) {
  704. auto used_fg = item.first;
  705. if (used_fg == fg) {
  706. func_graph_used_total_analysis_[fg].add(used_fg);
  707. continue;
  708. }
  709. if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
  710. todo_new.push_back(used_fg);
  711. }
  712. MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
  713. func_graph_used_total_analysis_[fg].add(used_fg);
  714. }
  715. }
  716. todo = todo_new;
  717. }
  718. }
  719. bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
  720. MS_EXCEPTION_IF_NULL(manager);
  721. std::vector<FuncGraphPtr> todo;
  722. std::vector<FuncGraphPtr> todo_new;
  723. todo.push_back(fg);
  724. FuncGraphSet used_total;
  725. while (!todo.empty()) {
  726. todo_new.clear();
  727. for (auto &gt : todo) {
  728. for (auto &item : gt->func_graphs_used()) {
  729. auto used_g = item.first;
  730. if (used_g == fg) {
  731. return true;
  732. }
  733. if (used_total.count(used_g) == 0) {
  734. todo_new.push_back(used_g);
  735. }
  736. used_total.add(used_g);
  737. }
  738. }
  739. todo = todo_new;
  740. }
  741. return false;
  742. }
  743. void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
  744. this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
  745. }
  746. void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
  747. MS_EXCEPTION_IF_NULL(trace);
  748. auto res = std::find(trace->begin(), trace->end(), fg);
  749. // find recursive
  750. if (res != trace->end()) {
  751. auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
  752. for (auto iter = res; iter != trace->end(); (void)iter++) {
  753. MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
  754. recursive_map_[*iter] = recur_ptr;
  755. }
  756. } else {
  757. trace->push_back(fg);
  758. auto &items = fg->func_graphs_used();
  759. for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
  760. CheckRecursiveGraphs(iter->first, trace);
  761. }
  762. trace->pop_back();
  763. if (!recursive_map_.count(fg)) {
  764. recursive_map_[fg] = nullptr;
  765. }
  766. }
  767. }
  768. bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
  769. if (fg->seen_ == seen_num) {
  770. MS_LOG(DEBUG) << fg->ToString() << " had been checked";
  771. return false;
  772. }
  773. auto &j_fgs = fg->j_func_graphs();
  774. if (!j_fgs.empty()) {
  775. // check g1->J(fg)->g2->g cycle;
  776. auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
  777. return iter.first->seen_ != seen_num;
  778. });
  779. if (contains_j != j_fgs.end()) {
  780. MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
  781. return true;
  782. }
  783. }
  784. fg->seen_ = seen_num;
  785. // check if func graphs used contains J(func_graph);
  786. for (auto &item : fg->func_graphs_used()) {
  787. auto used_g = item.first;
  788. if (SeekJ(used_g, seen_num)) {
  789. MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
  790. return true;
  791. }
  792. }
  793. MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)";
  794. return false;
  795. }
  796. void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
  797. this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
  798. }
  799. } // namespace mindspore