You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

manager.cc 40 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/manager.h"
  19. #include <algorithm>
  20. #include <numeric>
  21. #include <list>
  22. #include "./common.h"
  23. #include "utils/profile.h"
  24. #include "operator/ops.h"
  25. #include "debug/trace.h"
  26. namespace mindspore {
  27. FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
  28. auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
  29. m->Init();
  30. return m;
  31. }
  32. FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
  33. FuncGraphManagerPtr m = nullptr;
  34. bool root = false;
  35. for (auto &fg : func_graphs) {
  36. if (fg == nullptr) {
  37. continue;
  38. }
  39. if (fg->manager() != nullptr) {
  40. m = fg->manager();
  41. break;
  42. }
  43. }
  44. if (m == nullptr) {
  45. std::vector<FuncGraphPtr> tmp;
  46. m = MakeManager(tmp, manage);
  47. root = true;
  48. }
  49. for (auto &fg : func_graphs) {
  50. if (fg == nullptr) {
  51. continue;
  52. }
  53. m->AddFuncGraph(fg, root);
  54. }
  55. return m;
  56. }
  57. FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
  58. std::vector<FuncGraphPtr> func_graphs = {func_graph};
  59. return Manage(func_graphs, manage);
  60. }
  61. FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
  62. : roots_(roots), is_manage_(manage) {
  63. Reset();
  64. }
  65. void FuncGraphManager::Reset() {
  66. func_graphs_ = FuncGraphSet();
  67. all_nodes_ = AnfNodeSet();
  68. node_users_ = NodeUsersMap();
  69. signals_ = std::make_shared<Signals>();
  70. nodes_ = std::make_shared<NodesCollector>(this);
  71. valuenodes_ = std::make_shared<ValueNodesCollector>(this);
  72. free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
  73. func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this);
  74. func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
  75. func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this);
  76. func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this);
  77. func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
  78. func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
  79. func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);
  80. func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
  81. func_graph_parent_ = std::make_shared<ParentComputer>(this);
  82. children_ = std::make_shared<ChildrenComputer>(this);
  83. scopes_ = std::make_shared<ScopeComputer>(this);
  84. free_variables_total_ = std::make_shared<FVTotalComputer>(this);
  85. func_graphs_used_total_ = std::make_shared<FuncGraphsUsedTotalComputer>(this);
  86. recursive_ = std::make_shared<RecursiveComputer>(this);
  87. j_total_ = std::make_shared<FuncGraphJTotalComputer>(this);
  88. }
  89. void FuncGraphManager::Init() {
  90. auto roots = roots_;
  91. roots_ = FuncGraphSet();
  92. for (auto &fg : roots) {
  93. AddFuncGraph(fg, true);
  94. }
  95. }
  96. FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
  97. MS_EXCEPTION_IF_NULL(fg);
  98. MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
  99. func_graph_parents_total_->Recompute(fg);
  100. MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
  101. return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
  102. }
  103. FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
  104. MS_EXCEPTION_IF_NULL(fg);
  105. MS_EXCEPTION_IF_NULL(func_graph_parent_);
  106. MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
  107. func_graph_parent_->Recompute(fg);
  108. if (func_graph_parent_->parent_analysis().count(fg) == 0) {
  109. MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString();
  110. return nullptr;
  111. }
  112. MS_LOG(DEBUG) << "End parents func graph " << fg->ToString();
  113. return func_graph_parent_->parent_analysis()[fg];
  114. }
  115. FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
  116. MS_EXCEPTION_IF_NULL(fg);
  117. MS_EXCEPTION_IF_NULL(children_);
  118. MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
  119. children_->Recompute(fg);
  120. return children_->children_analysis()[fg];
  121. }
  122. FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
  123. MS_EXCEPTION_IF_NULL(fg);
  124. MS_EXCEPTION_IF_NULL(scopes_);
  125. MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
  126. scopes_->Recompute(fg);
  127. MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString();
  128. return scopes_->scope_analysis()[fg];
  129. }
  130. FVTotalMap &FuncGraphManager::free_variables_total() const {
  131. MS_EXCEPTION_IF_NULL(free_variables_total_);
  132. free_variables_total_->Recompute();
  133. return free_variables_total_->fv_total_analysis();
  134. }
  135. FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
  136. MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
  137. func_graphs_used_total_->Recompute(fg);
  138. return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
  139. }
  140. bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
  141. MS_EXCEPTION_IF_NULL(fg);
  142. recursive_->Recompute(fg);
  143. if (recursive_->recursive_analysis().count(fg) == 0) {
  144. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  145. return false;
  146. }
  147. return recursive_->recursive_analysis()[fg];
  148. }
  149. std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
  150. MS_EXCEPTION_IF_NULL(fg);
  151. if (recursive(fg)) {
  152. if (!recursive_->recursive_map().count(fg)) {
  153. auto trace = std::list<FuncGraphPtr>();
  154. recursive_->CheckRecursiveGraphs(fg, &trace);
  155. }
  156. if (recursive_->recursive_map().count(fg) == 0) {
  157. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  158. return nullptr;
  159. }
  160. return recursive_->recursive_map()[fg];
  161. } else {
  162. return nullptr;
  163. }
  164. }
  165. bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
  166. MS_EXCEPTION_IF_NULL(j_total_);
  167. MS_EXCEPTION_IF_NULL(fg);
  168. j_total_->Recompute(fg);
  169. if (j_total_->j_total_analysis().count(fg) == 0) {
  170. MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
  171. return false;
  172. }
  173. return j_total_->j_total_analysis()[fg];
  174. }
  175. // add a func graph to this manager, optionally as a root func graph.
  176. void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
  177. MS_EXCEPTION_IF_NULL(func_graph);
  178. if (is_root) {
  179. roots_.add(func_graph);
  180. }
  181. if (func_graphs_.contains(func_graph)) {
  182. return;
  183. }
  184. AddIntoManaged(func_graph);
  185. MS_EXCEPTION_IF_NULL(signals_);
  186. signals_->AddFuncGraph(func_graph);
  187. std::vector<AnfNodePtr> para = func_graph->parameters();
  188. AcquireNodes(para);
  189. std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
  190. AcquireNodes(return_vec);
  191. }
  192. // clear the all information in manager
  193. void FuncGraphManager::Clear() {
  194. func_graphs_.clear();
  195. all_nodes_.clear();
  196. node_users_.clear();
  197. roots_.clear();
  198. signals_->InvalidateCollector();
  199. signals_->InvalidateComputer();
  200. }
  201. void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
  202. MS_LOG(DEBUG) << "Start keep roots";
  203. bool root_exist = false;
  204. for (auto &item : func_graphs) {
  205. if (roots_.contains(item)) {
  206. root_exist = true;
  207. break;
  208. }
  209. }
  210. // if the new_root in roots_, we add new_root first, then calculate the func_graphs
  211. // relation to new_root, remove the func_graphs not relation to new_root
  212. // if the new_root not in roots_, we clear the all func_graphs in manager
  213. // then add the new_root
  214. if (root_exist || func_graphs.empty()) {
  215. FuncGraphSet roots(func_graphs);
  216. if (roots.empty()) {
  217. roots = roots_;
  218. } else {
  219. roots_.clear();
  220. for (auto &item : roots) {
  221. AddFuncGraph(item, true);
  222. }
  223. }
  224. FuncGraphSet keep;
  225. for (auto &item : roots) {
  226. MS_LOG(DEBUG) << "roots: " << item->ToString();
  227. keep.update(func_graphs_used_total(item));
  228. #ifdef DEBUG
  229. for (auto &k : keep) {
  230. MS_LOG(DEBUG) << "keep: " << k->ToString();
  231. }
  232. #endif
  233. }
  234. MaybeDropFuncGraphs(func_graphs_ - keep, true);
  235. } else {
  236. Clear();
  237. FuncGraphSet roots(func_graphs);
  238. for (auto &item : roots) {
  239. AddFuncGraph(item, true);
  240. }
  241. }
  242. }
  243. void FuncGraphManager::RemoveRoots() {
  244. MS_LOG(DEBUG) << "Start remove roots";
  245. roots_.clear();
  246. MaybeDropFuncGraphs(func_graphs_, true);
  247. }
  248. void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
  249. MS_EXCEPTION_IF_NULL(fg);
  250. if (is_manage_) {
  251. if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
  252. MS_LOG(WARNING) << "A func graph can only have one manager.";
  253. }
  254. FuncGraphManagerPtr this_manager = shared_from_this();
  255. fg->set_manager(this_manager);
  256. }
  257. func_graphs_.add(fg);
  258. }
  259. void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
  260. FuncGraphSet todo(func_graphs);
  261. std::set<FuncGraphPtr> dropped;
  262. // int count = 0;
  263. while (!todo.empty()) {
  264. FuncGraphPtr func_graph = todo.pop();
  265. MS_EXCEPTION_IF_NULL(func_graph);
  266. MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString();
  267. if (roots_.contains(func_graph)) {
  268. MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
  269. continue;
  270. }
  271. MS_EXCEPTION_IF_NULL(func_graph_users_);
  272. auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
  273. if (!users.empty() && !ignore_users) {
  274. MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
  275. continue;
  276. }
  277. if (dropped.find(func_graph) != dropped.end()) {
  278. MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString();
  279. continue;
  280. }
  281. (void)dropped.insert(func_graph);
  282. std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
  283. todo.update(MaybeDropNodes(return_vec));
  284. }
  285. MS_EXCEPTION_IF_NULL(signals_);
  286. for (auto &fg : dropped) {
  287. MS_EXCEPTION_IF_NULL(fg);
  288. signals_->DropFuncGraph(fg);
  289. all_nodes_.difference_update(fg->parameters());
  290. (void)func_graphs_.erase(fg);
  291. if (fg->manager().get() == this) {
  292. fg->set_manager(nullptr);
  293. }
  294. MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString();
  295. }
  296. }
  297. void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) {
  298. MS_EXCEPTION_IF_NULL(inp);
  299. if (direction == kDecEdge) {
  300. MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
  301. auto &users_node = node_users_[inp];
  302. if (!users_node.contains(make_pair(node, index))) {
  303. return;
  304. }
  305. (void)users_node.erase(make_pair(node, index));
  306. signals_->DropEdge(node, index, inp);
  307. } else {
  308. MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
  309. if (inp->func_graph() != nullptr) {
  310. AddFuncGraph(inp->func_graph());
  311. }
  312. if (IsValueNode<FuncGraph>(inp)) {
  313. MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
  314. AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
  315. }
  316. auto &users_node = node_users_[inp];
  317. users_node.add(make_pair(node, index));
  318. MS_EXCEPTION_IF_NULL(signals_);
  319. signals_->AddEdge(node, index, inp);
  320. }
  321. }
  322. void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) {
  323. MS_EXCEPTION_IF_NULL(node);
  324. if (node->isa<CNode>()) {
  325. auto cnode = node->cast<CNodePtr>();
  326. int index = 0;
  327. for (auto &inp : cnode->inputs()) {
  328. ProcessEdge(cnode, index, inp, direction);
  329. ++index;
  330. }
  331. }
  332. }
  333. IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) {
  334. if (all_nodes_.contains(node)) {
  335. return EXCLUDE;
  336. } else {
  337. return FOLLOW;
  338. }
  339. }
  340. void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
  341. AnfNodeSet acq;
  342. for (auto &node : nodes) {
  343. std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
  344. AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit));
  345. all_nodes_.update(new_nodes);
  346. acq.update(new_nodes);
  347. }
  348. for (auto &node : acq) {
  349. MS_EXCEPTION_IF_NULL(node);
  350. FuncGraphPtr fg = node->func_graph();
  351. if (fg != nullptr) {
  352. AddFuncGraph(fg);
  353. }
  354. signals_->AddNode(node);
  355. ProcessInputs(node, kIncEdge);
  356. }
  357. }
  358. FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
  359. AnfNodeSet nodes_ordered(nodes);
  360. FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
  361. MS_EXCEPTION_IF_NULL(signals_);
  362. while (!nodes_ordered.empty()) {
  363. AnfNodePtr node = nodes_ordered.pop();
  364. MS_EXCEPTION_IF_NULL(node);
  365. if (!all_nodes_.contains(node)) {
  366. continue;
  367. }
  368. AnfNodeIndexSet &users = node_users_[node];
  369. std::vector<AnfNodePtr> parameters;
  370. if (!users.empty() ||
  371. (node->isa<Parameter>() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) {
  372. continue;
  373. }
  374. if (IsValueNode<FuncGraph>(node)) {
  375. auto fg = GetValueNode<FuncGraphPtr>(node);
  376. func_graphs_to_check->add(fg);
  377. MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString()
  378. << " to null";
  379. }
  380. ProcessInputs(node, kDecEdge);
  381. (void)all_nodes_.erase(node);
  382. signals_->DropNode(node);
  383. if (node->isa<CNode>()) {
  384. auto cnode = node->cast<CNodePtr>();
  385. nodes_ordered.update(cnode->inputs());
  386. }
  387. (void)node_users_.erase(node);
  388. }
  389. return func_graphs_to_check;
  390. }
  391. void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters) {
  392. auto tr = Transact();
  393. tr.SetParameters(fg, parameters);
  394. tr.Commit();
  395. }
  396. bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  397. auto tr = Transact();
  398. bool success = tr.Replace(old_node, new_node);
  399. if (success) {
  400. tr.Commit();
  401. }
  402. return success;
  403. }
  404. void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
  405. auto tr = Transact();
  406. tr.SetEdge(node, index, value);
  407. tr.Commit();
  408. }
  409. void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) {
  410. AnfNodePtr source_return = source->get_return();
  411. AnfNodePtr source_output = source->output();
  412. AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
  413. int index = 0;
  414. (void)node_users_[source_prim].erase(make_pair(source_return, index));
  415. signals_->DropEdge(source_return, index, source_prim);
  416. index = 1;
  417. (void)node_users_[source_output].erase(make_pair(source_return, index));
  418. signals_->DropEdge(source_return, index, source_output);
  419. (void)all_nodes_.erase(source_return);
  420. (void)node_users_.erase(source_return);
  421. signals_->DropNode(source_return);
  422. for (auto &node : source->nodes()) {
  423. node->set_func_graph(target);
  424. if (node->scope() == kDefaultScope) {
  425. node->set_scope(scope);
  426. }
  427. }
  428. for (auto &used : source->func_graphs_used()) {
  429. (void)func_graph_users_->Inc(used.first, target, used.second);
  430. (void)this->func_graph_users()[used.first].erase(source);
  431. }
  432. for (auto &child : this->func_graph_child_direct()[source]) {
  433. (void)func_graph_parents_direct_->Inc(child.first, target, child.second);
  434. (void)this->func_graph_parents_direct()[child.first].erase(source);
  435. }
  436. for (auto &fv_count : this->free_variables_direct()[source]) {
  437. auto fv_g = fv_count.first->func_graph();
  438. auto &count_on_g = this->func_graph_child_direct()[fv_g];
  439. auto pair = count_on_g.find(source);
  440. if (fv_g != target && pair != count_on_g.end()) {
  441. (void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
  442. }
  443. (void)count_on_g.erase(source);
  444. }
  445. signals_->MoveAllCNode(source, target);
  446. signals_->InvalidateComputer();
  447. signals_->DropFuncGraph(source);
  448. all_nodes_.difference_update(source->parameters());
  449. (void)func_graphs_.erase(source);
  450. if (source->manager().get() == this) {
  451. source->set_manager(nullptr);
  452. }
  453. }
  454. FuncGraphTransaction FuncGraphManager::Transact() {
  455. auto tr = FuncGraphTransaction(this);
  456. return tr;
  457. }
  458. void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges,
  459. EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) {
  460. for (auto &iter : changes) {
  461. auto operation = iter.op;
  462. auto args = iter.args;
  463. if (operation == Change::kTxSetEdge) {
  464. auto edge = args.cast<ArgsOfSetEdge>();
  465. auto old_node = edge.root_node->input(edge.index);
  466. (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
  467. (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
  468. (*rms)[old_node] += 1;
  469. (*adds)[edge.new_node] += 1;
  470. edge.root_node->set_input(edge.index, edge.new_node);
  471. } else if (operation == Change::kTxSetParams) {
  472. auto param = args.cast<ArgsOfSetParams>();
  473. MS_EXCEPTION_IF_NULL(param.func_graph);
  474. auto old_parameters = param.func_graph->parameters();
  475. for (auto &p : param.params) {
  476. (*adds)[p] += 1;
  477. }
  478. for (auto &p : old_parameters) {
  479. (*rms)[p] += 1;
  480. }
  481. param.func_graph->set_parameters(param.params);
  482. }
  483. }
  484. }
  485. void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) {
  486. EdgeTupleCounter add_edges;
  487. EdgeTupleCounter rm_edges;
  488. Counter<AnfNodePtr> adds;
  489. Counter<AnfNodePtr> rms;
  490. ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
  491. auto sub_edges = add_edges - rm_edges;
  492. for (auto &iter : sub_edges) {
  493. auto root_node = iter.first.first;
  494. int index = iter.first.second.first;
  495. auto new_node = iter.first.second.second;
  496. ProcessEdge(root_node, index, new_node, kIncEdge);
  497. }
  498. auto sub_nodes = adds - rms;
  499. std::vector<AnfNodePtr> nodes;
  500. (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
  501. [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
  502. AcquireNodes(nodes);
  503. auto sub_edges_reverse = rm_edges - add_edges;
  504. for (auto &iter : sub_edges_reverse) {
  505. auto root_node = iter.first.first;
  506. int index = iter.first.second.first;
  507. auto old_node = iter.first.second.second;
  508. ProcessEdge(root_node, index, old_node, kDecEdge);
  509. }
  510. auto sub_nodes_reverse = rms - adds;
  511. std::vector<AnfNodePtr> nodes_reverse;
  512. (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
  513. [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
  514. auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
  515. MaybeDropFuncGraphs(*drop_func_graphs);
  516. }
  517. void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params) {
  518. changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
  519. }
  520. bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  521. MS_EXCEPTION_IF_NULL(old_node);
  522. MS_EXCEPTION_IF_NULL(new_node);
  523. FuncGraphPtr old_func_graph = old_node->func_graph();
  524. if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) {
  525. MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
  526. return false;
  527. }
  528. auto users = manager_->node_users()[old_node];
  529. for (auto &node : users) {
  530. SetEdge(node.first, node.second, new_node);
  531. }
  532. return true;
  533. }
  534. void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
  535. if (k < 0) {
  536. MS_LOG(EXCEPTION) << "Invalid value k = " << k;
  537. }
  538. MS_EXCEPTION_IF_NULL(src_node);
  539. auto cnode = src_node->cast<CNodePtr>();
  540. if (cnode == nullptr) {
  541. MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed.";
  542. }
  543. changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)});
  544. }
  545. void FuncGraphTransaction::Commit() {
  546. std::vector<Change> changes;
  547. changes_.swap(changes);
  548. manager_->CommitChanges(changes);
  549. }
  550. FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
  551. : manager_(manager), include_func_graph_none_(false) {
  552. manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
  553. manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
  554. manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge);
  555. manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge);
  556. manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
  557. }
  558. NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
  559. include_func_graph_none_ = true;
  560. nodes_analysis_[nullptr] = AnfNodeSet();
  561. manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode);
  562. manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode);
  563. }
  564. void NodesCollector::OnAddNode(AnfNodePtr n) {
  565. if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) {
  566. nodes_analysis_[n->func_graph()] = AnfNodeSet();
  567. }
  568. nodes_analysis_[n->func_graph()].add(n);
  569. }
  570. void NodesCollector::OnDropNode(AnfNodePtr n) {
  571. (void)nodes_analysis_[n->func_graph()].erase(n);
  572. auto graph = n->func_graph();
  573. // Remove the node from order list.
  574. if (graph) {
  575. graph->EraseUnusedNodeInOrder(n);
  576. }
  577. }
  578. void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  579. // change the owner of node except for the src's return node
  580. for (auto &it : nodes_analysis_[src]) {
  581. nodes_analysis_[dst].add(it);
  582. }
  583. (void)nodes_analysis_.erase(src);
  584. }
  585. void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
  586. DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
  587. MS_EXCEPTION_IF_NULL(manager_);
  588. manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
  589. }
  590. void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
  591. bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
  592. auto &d = count_nodes_map_[func_graph];
  593. if (d.count(key) == 0) {
  594. d[key] = count;
  595. return true;
  596. } else {
  597. d[key] += count;
  598. }
  599. return false;
  600. }
  601. bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
  602. MS_EXCEPTION_IF_NULL(func_graph);
  603. auto &d = count_nodes_map_[func_graph];
  604. if (d.count(key) != 0) {
  605. if (d[key] == count) {
  606. (void)d.erase(key);
  607. return true;
  608. } else {
  609. d[key] -= count;
  610. if (d[key] < 0) {
  611. MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
  612. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  613. }
  614. }
  615. }
  616. return false;
  617. }
  618. bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
  619. if (count > 0) {
  620. return Inc(func_graph, key, count);
  621. } else if (count < 0) {
  622. return Dec(func_graph, key, -count);
  623. } else {
  624. MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
  625. << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  626. }
  627. }
  628. bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
  629. auto &d = count_func_graphs_map_[func_graph];
  630. if (d.count(key) == 0) {
  631. d[key] = count;
  632. return true;
  633. } else {
  634. d[key] += count;
  635. }
  636. return false;
  637. }
  638. bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
  639. auto &d = count_func_graphs_map_[func_graph];
  640. if (d.count(key) != 0) {
  641. if (d[key] == count) {
  642. (void)d.erase(key);
  643. return true;
  644. } else {
  645. d[key] -= count;
  646. if (d[key] < 0) {
  647. MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
  648. << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  649. }
  650. }
  651. }
  652. return false;
  653. }
  654. bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
  655. if (count > 0) {
  656. return Inc(func_graph, key, count);
  657. } else if (count < 0) {
  658. return Dec(func_graph, key, -count);
  659. } else {
  660. MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
  661. << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  662. }
  663. }
  664. void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  665. MS_EXCEPTION_IF_NULL(node);
  666. if (inp->isa<ValueNode>()) {
  667. (void)Mod(node->func_graph(), inp, direction);
  668. }
  669. }
  670. void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  671. for (auto &it : count_nodes_map_[src]) {
  672. (void)Inc(dst, it.first, it.second);
  673. }
  674. (void)count_nodes_map_.erase(src);
  675. }
  676. // if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self
  677. void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  678. if (IsValueNode<FuncGraph>(inp)) {
  679. (void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction);
  680. }
  681. }
  682. void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  683. for (auto &it : count_nodes_map_[src]) {
  684. (void)Inc(dst, it.first, it.second);
  685. }
  686. (void)count_nodes_map_.erase(src);
  687. }
  688. void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  689. MS_EXCEPTION_IF_NULL(node);
  690. MS_EXCEPTION_IF_NULL(inp);
  691. FuncGraphPtr fg1 = node->func_graph();
  692. FuncGraphPtr fg2 = inp->func_graph();
  693. if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
  694. (void)Mod(fg1, inp, direction);
  695. }
  696. }
  697. void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  698. for (auto &it : count_nodes_map_[src]) {
  699. FuncGraphPtr fg2 = it.first->func_graph();
  700. if (fg2 != dst) {
  701. (void)Inc(dst, it.first, it.second);
  702. }
  703. }
  704. (void)count_nodes_map_.erase(src);
  705. }
  706. static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
  707. FuncGraphPtr gn = std::make_shared<FuncGraph>();
  708. (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
  709. return gn;
  710. }
  711. void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  712. MS_EXCEPTION_IF_NULL(node);
  713. MS_EXCEPTION_IF_NULL(inp);
  714. FuncGraphPtr fg1 = node->func_graph();
  715. FuncGraphPtr fg2 = inp->func_graph();
  716. if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
  717. (void)Mod(fg2, fg1, direction);
  718. }
  719. }
  720. void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  721. for (auto &it : count_func_graphs_map_[src]) {
  722. FuncGraphPtr fg = it.first;
  723. if (fg != dst) {
  724. (void)Inc(dst, fg, it.second);
  725. }
  726. }
  727. (void)count_func_graphs_map_.erase(src);
  728. }
  729. void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  730. MS_EXCEPTION_IF_NULL(node);
  731. FuncGraphPtr fg1 = node->func_graph();
  732. // possible child parent
  733. if (IsValueNode<FuncGraph>(inp)) {
  734. FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp);
  735. if (Mod(fg1, ParentProxy(fg2), direction)) {
  736. manager_->signals()->InvalidateComputer();
  737. }
  738. }
  739. // from fv
  740. FuncGraphPtr fg2 = inp->func_graph();
  741. if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
  742. // node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent
  743. if (Mod(fg1, fg2, direction)) {
  744. manager_->signals()->InvalidateComputer();
  745. }
  746. }
  747. }
  748. void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  749. for (auto &it : count_func_graphs_map_[src]) {
  750. if (it.first != dst) {
  751. (void)Inc(dst, it.first, it.second);
  752. }
  753. }
  754. (void)count_func_graphs_map_.erase(src);
  755. }
  756. void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  757. MS_EXCEPTION_IF_NULL(node);
  758. if (IsValueNode<FuncGraph>(inp)) {
  759. (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
  760. }
  761. }
  762. void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  763. // all graph use in src need to change to dst, so meger the to dst use
  764. for (auto &it : count_func_graphs_map_[src]) {
  765. (void)Inc(dst, it.first, it.second);
  766. }
  767. (void)count_func_graphs_map_[dst].erase(src);
  768. (void)count_func_graphs_map_.erase(src);
  769. }
  770. void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  771. MS_EXCEPTION_IF_NULL(node);
  772. if (IsValueNode<FuncGraph>(inp)) {
  773. (void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction);
  774. }
  775. }
  776. void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) {
  777. // all graph use in src need to change to dst, so add dst user
  778. (void)count_func_graphs_map_.erase(src);
  779. }
  780. void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  781. MS_EXCEPTION_IF_NULL(node);
  782. if (IsValueNode<FuncGraph>(inp)) {
  783. (void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction);
  784. }
  785. }
  786. void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  787. for (auto &it : count_nodes_map_[src]) {
  788. (void)Inc(dst, it.first, it.second);
  789. }
  790. (void)count_nodes_map_.erase(src);
  791. }
  792. void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
  793. if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
  794. (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
  795. MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph "
  796. << GetValueNode<FuncGraphPtr>(inp)->ToString() << " which contains J(func_graph), dir: " << direction;
  797. }
  798. }
  799. void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
  800. // all graph use in src need to change to dst, so meger the to dst use
  801. for (auto &it : count_func_graphs_map_[src]) {
  802. (void)Inc(dst, it.first, it.second);
  803. }
  804. (void)count_func_graphs_map_.erase(src);
  805. }
  806. DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
  807. MS_EXCEPTION_IF_NULL(manager_);
  808. manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
  809. validate_ = false;
  810. }
  811. void DepComputer::Recompute() {
  812. if (!validate_) {
  813. RealRecompute();
  814. validate_ = true;
  815. }
  816. }
  817. void DepComputer::Recompute(const FuncGraphPtr &fg) {
  818. if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
  819. RealRecompute(fg);
  820. func_graphs_validate_[fg] = true;
  821. }
  822. }
  823. FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
  824. if (path == nullptr || path->contains(fg)) {
  825. return std::make_shared<FuncGraphSet>();
  826. }
  827. FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
  828. FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
  829. for (auto &dep : deps[fg]) {
  830. MS_EXCEPTION_IF_NULL(dep.first);
  831. auto proxy = dep.first->transforms().find("proxy");
  832. if (proxy != dep.first->transforms().end()) {
  833. path->add(fg);
  834. auto gt = proxy->second.func_graph();
  835. parents->update(SeekParents(gt, path));
  836. } else {
  837. parents->add(dep.first);
  838. }
  839. }
  840. (void)parents->erase(fg);
  841. return parents;
  842. }
  843. void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
  844. MS_EXCEPTION_IF_NULL(fg);
  845. all_parents_direct_ = &(manager_->func_graph_parents_direct());
  846. MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size();
  847. func_graph_parents_total_analysis_[fg].update(SeekParents(fg));
  848. MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
  849. }
  850. bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
  851. auto l1 = lhs.second.size();
  852. auto l2 = rhs.second.size();
  853. return l1 < l2;
  854. }
  855. void ParentComputer::RealRecompute(FuncGraphPtr fg) {
  856. this->parent_analysis_[fg] = nullptr;
  857. // Note: must be a copy other than reference as it is modified thereafter.
  858. auto deps = this->manager_->func_graph_parents_total(fg);
  859. if (deps.empty()) {
  860. this->parent_analysis_[fg] = nullptr;
  861. return;
  862. } else if (deps.size() == 1) {
  863. this->parent_analysis_[fg] = deps.pop();
  864. return;
  865. } else {
  866. // return nearest parent as parent
  867. FuncGraphSet deps_copy(deps);
  868. for (auto &dep : deps) {
  869. auto parent_deps = this->manager_->func_graph_parents_total(dep);
  870. for (auto &p_d : parent_deps) {
  871. if (deps_copy.count(p_d)) {
  872. (void)deps_copy.erase(p_d);
  873. }
  874. }
  875. if (deps_copy.size() == 1) {
  876. this->parent_analysis_[fg] = deps_copy.pop();
  877. return;
  878. }
  879. }
  880. }
  881. }
  882. void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
  883. MS_EXCEPTION_IF_NULL(manager_);
  884. auto used_fg_total = manager_->func_graphs_used_total(fg);
  885. for (auto &used_fg : used_fg_total) {
  886. if (manager_->parent(used_fg) == fg) {
  887. children_analysis_[fg].add(used_fg);
  888. }
  889. }
  890. }
  891. void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
  892. MS_EXCEPTION_IF_NULL(manager_);
  893. auto &children = manager_->children(fg);
  894. scope_analysis_[fg] = FuncGraphSet();
  895. scope_analysis_[fg].add(fg);
  896. for (auto &child : children) {
  897. scope_analysis_[fg].add(child);
  898. }
  899. }
  900. void FVTotalComputer::RealRecompute() {
  901. auto manager = DepComputer::manager_;
  902. MS_EXCEPTION_IF_NULL(manager);
  903. for (auto &fg : manager->func_graphs()) {
  904. fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
  905. count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
  906. count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
  907. }
  908. for (auto &fg : manager->func_graphs()) {
  909. AnfNodeCounterMap items = manager->free_variables_direct()[fg];
  910. for (auto &iter : items) {
  911. auto curr = fg;
  912. while (curr) {
  913. (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
  914. curr = manager->parent(curr);
  915. const AnfNodeSet &nodes = manager->nodes()[curr];
  916. if (nodes.contains(iter.first)) {
  917. break;
  918. }
  919. }
  920. }
  921. auto items_fg = manager->func_graphs_used()[fg];
  922. for (auto &iter : items_fg) {
  923. auto p = manager->parent(iter.first);
  924. if (p == nullptr) {
  925. continue;
  926. }
  927. auto curr = fg;
  928. while (curr != p) {
  929. (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
  930. curr = manager->parent(curr);
  931. }
  932. }
  933. }
  934. for (auto &fg : manager->func_graphs()) {
  935. auto &fvp = count_nodes_map_[fg];
  936. auto &fvg = count_func_graphs_map_[fg];
  937. for (auto &item : fvp) {
  938. fv_total_analysis_[fg][item.first] = item.second;
  939. }
  940. for (auto &item : fvg) {
  941. fv_total_analysis_[fg][item.first] = item.second;
  942. }
  943. }
  944. }
  945. void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
  946. MS_EXCEPTION_IF_NULL(manager_);
  947. auto &used = this->manager_->func_graphs_used();
  948. std::vector<FuncGraphPtr> todo;
  949. std::vector<FuncGraphPtr> todo_new;
  950. todo.push_back(fg);
  951. while (!todo.empty()) {
  952. todo_new.clear();
  953. for (auto &gt : todo) {
  954. for (auto &item : used[gt]) {
  955. auto used_fg = item.first;
  956. if (used_fg == fg) {
  957. func_graph_used_total_analysis_[fg].add(used_fg);
  958. continue;
  959. }
  960. if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
  961. todo_new.push_back(used_fg);
  962. }
  963. MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
  964. func_graph_used_total_analysis_[fg].add(used_fg);
  965. }
  966. }
  967. todo = todo_new;
  968. }
  969. }
  970. bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
  971. MS_EXCEPTION_IF_NULL(manager);
  972. auto &used = manager->func_graphs_used();
  973. std::vector<FuncGraphPtr> todo;
  974. std::vector<FuncGraphPtr> todo_new;
  975. todo.push_back(fg);
  976. FuncGraphSet used_total;
  977. while (!todo.empty()) {
  978. todo_new.clear();
  979. for (auto &gt : todo) {
  980. for (auto &item : used[gt]) {
  981. auto used_g = item.first;
  982. if (used_g == fg) {
  983. return true;
  984. }
  985. if (used_total.count(used_g) == 0) {
  986. todo_new.push_back(used_g);
  987. }
  988. used_total.add(used_g);
  989. }
  990. }
  991. todo = todo_new;
  992. }
  993. return false;
  994. }
  995. void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
  996. this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
  997. }
  998. void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
  999. MS_EXCEPTION_IF_NULL(trace);
  1000. auto res = std::find(trace->begin(), trace->end(), fg);
  1001. // find recursive
  1002. if (res != trace->end()) {
  1003. auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
  1004. for (auto iter = res; iter != trace->end(); (void)iter++) {
  1005. MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
  1006. recursive_map_[*iter] = recur_ptr;
  1007. }
  1008. } else {
  1009. trace->push_back(fg);
  1010. auto &used_fgs = manager_->func_graphs_used()[fg];
  1011. for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
  1012. CheckRecursiveGraphs(iter->first, trace);
  1013. }
  1014. trace->pop_back();
  1015. if (!recursive_map_.count(fg)) {
  1016. recursive_map_[fg] = nullptr;
  1017. }
  1018. }
  1019. }
  1020. bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
  1021. MS_EXCEPTION_IF_NULL(path);
  1022. if (path->contains(fg)) {
  1023. MS_LOG(DEBUG) << fg->ToString() << " had been checked";
  1024. return false;
  1025. }
  1026. MS_EXCEPTION_IF_NULL(manager_);
  1027. auto &func_graph_counter_map = manager_->func_graph_j_direct();
  1028. if (!func_graph_counter_map[fg].empty()) {
  1029. // check g1->J(fg)->g2->g cycle;
  1030. auto contains_j =
  1031. std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(),
  1032. [path](const std::pair<FuncGraphPtr, int> iter) { return !path->contains(iter.first); });
  1033. if (contains_j != func_graph_counter_map[fg].end()) {
  1034. MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
  1035. return true;
  1036. }
  1037. }
  1038. path->add(fg);
  1039. // check if func graphs used contains J(func_graph);
  1040. auto &used = this->manager_->func_graphs_used();
  1041. for (auto &item : used[fg]) {
  1042. auto used_g = item.first;
  1043. if (SeekJ(used_g, path)) {
  1044. MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
  1045. return true;
  1046. }
  1047. }
  1048. MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)";
  1049. return false;
  1050. }
  1051. void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
  1052. std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>();
  1053. this->j_total_analysis_[fg] = SeekJ(fg, path);
  1054. }
  1055. } // namespace mindspore