You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph_analyzer.cc 28 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/func_graph_analyzer.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "base/core_ops.h"
#include "utils/utils.h"
  23. namespace mindspore {
  24. const int64_t kGetAll = -1;
  25. using FuncClosurePtr = std::shared_ptr<FuncClosure>;
  26. class ValueGetter;
  27. using ValueGetterPtr = std::shared_ptr<ValueGetter>;
  28. class ValueManager;
  29. using ValueManagerPtr = std::shared_ptr<ValueManager>;
  30. ValueGetterPtr CreateValueGetter(const AnfNodePtr &node, const ValueManagerPtr &manager);
  31. class ValueManager : public std::enable_shared_from_this<ValueManager> {
  32. public:
  33. ValueManager() = default;
  34. ~ValueManager() = default;
  35. ValueGetterPtr GetValueGetter(const AnfNodePtr &node) {
  36. MS_LOG(DEBUG) << "Try get value getter of node: " << node->DebugString();
  37. const auto &it = values_getters_.find(node);
  38. if (it == values_getters_.end()) {
  39. auto new_value_getter = CreateValueGetter(node, shared_from_this());
  40. values_getters_[node] = new_value_getter;
  41. MS_LOG(DEBUG) << "Create new value getter of node: " << node->DebugString();
  42. return new_value_getter;
  43. }
  44. return it->second;
  45. }
  46. bool UpdateGraphRelations(const std::vector<FuncClosurePtr> &func_closures, const AnfNodePtr &call) {
  47. MS_LOG(DEBUG) << "Func closure size: " << func_closures.size() << ", call: " << call->DebugString();
  48. auto change1 = UpdateGraphRealCallers(func_closures, call);
  49. auto change2 = UpdateCallerClosures(func_closures, call);
  50. return change1 || change2;
  51. }
  52. std::vector<FuncClosurePtr> GetCallClosures(const AnfNodePtr &call, const FuncGraphPtr &fg) {
  53. const auto &it = caller_closures_.find(call);
  54. if (it == caller_closures_.end()) {
  55. return {};
  56. }
  57. const auto &closures = it->second;
  58. std::vector<FuncClosurePtr> ret;
  59. (void)std::copy_if(closures.begin(), closures.end(), std::back_inserter(ret),
  60. [&fg](const FuncClosurePtr &closure) { return closure->func_graph_ == fg; });
  61. return ret;
  62. }
  63. std::vector<AnfNodePtr> GetArg(const AnfNodePtr &param, const AnfNodePtr &call) {
  64. MS_EXCEPTION_IF_NULL(param);
  65. auto fg = param->func_graph();
  66. MS_EXCEPTION_IF_NULL(fg);
  67. const auto &parameters = fg->parameters();
  68. int64_t param_index = -1;
  69. for (size_t i = 0; i < parameters.size(); i++) {
  70. if (parameters[i] == param) {
  71. param_index = i;
  72. }
  73. }
  74. if (param_index == -1) {
  75. MS_LOG(EXCEPTION) << "Failed failed arg of parameter: " << param->DebugString()
  76. << ",call: " << call->DebugString();
  77. }
  78. std::vector<AnfNodePtr> ret_args;
  79. auto call_cnode = call->cast<CNodePtr>();
  80. MS_EXCEPTION_IF_NULL(call_cnode);
  81. auto closures = GetCallClosures(call, fg);
  82. for (const auto &closure : closures) {
  83. auto args = closure->GetArgs();
  84. (void)std::copy(call_cnode->inputs().begin() + 1, call_cnode->inputs().end(), std::back_inserter(args));
  85. if (parameters.size() != args.size()) {
  86. MS_LOG(EXCEPTION) << "Parameters size and args size are not equal, parameters size: " << parameters.size()
  87. << ", args size: " << args.size() << ". Parameter: " << param->DebugString()
  88. << ", call: " << call_cnode->DebugString();
  89. }
  90. ret_args.emplace_back(args[param_index]);
  91. }
  92. return ret_args;
  93. }
  94. HashMap<FuncGraphPtr, std::vector<CNodePtr>> func_graph_real_users_;
  95. HashMap<AnfNodePtr, std::vector<FuncClosurePtr>> caller_closures_;
  96. bool has_incorporate_call_ = false;
  97. private:
  98. HashMap<AnfNodePtr, ValueGetterPtr> values_getters_;
  99. bool UpdateGraphRealCallers(const std::vector<FuncClosurePtr> &func_closures, const AnfNodePtr &call) {
  100. bool change = false;
  101. for (const auto &fg_closure : func_closures) {
  102. auto map_it = func_graph_real_users_.find(fg_closure->func_graph_);
  103. if (map_it != func_graph_real_users_.end()) {
  104. const auto &real_callers = map_it->second;
  105. if (std::find(real_callers.begin(), real_callers.end(), call->cast<CNodePtr>()) != real_callers.end()) {
  106. continue;
  107. }
  108. }
  109. MS_LOG(DEBUG) << "Fg: " << fg_closure->func_graph_->ToString() << ", user: " << call->DebugString();
  110. func_graph_real_users_[fg_closure->func_graph_].push_back(call->cast<CNodePtr>());
  111. change = true;
  112. }
  113. return change;
  114. }
  115. bool UpdateCallerClosures(const std::vector<FuncClosurePtr> &func_closures, const AnfNodePtr &call) {
  116. auto map_it = caller_closures_.find(call);
  117. if (map_it != caller_closures_.end()) {
  118. bool change = false;
  119. auto &closures = map_it->second;
  120. std::copy_if(func_closures.begin(), func_closures.end(), std::back_inserter(closures),
  121. [&closures, &change](const FuncClosurePtr &fg_closure) {
  122. if (!fg_closure->ExistInList(closures)) {
  123. change = true;
  124. return true;
  125. }
  126. return false;
  127. });
  128. return change;
  129. }
  130. caller_closures_[call] = func_closures;
  131. return true;
  132. }
  133. };
  134. class ValueGetter {
  135. public:
  136. ValueGetter(const AnfNodePtr &anf_node, const ValueManagerPtr &manager) : anf_node_(anf_node), manager_(manager) {}
  137. ~ValueGetter() = default;
  138. virtual ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path);
  139. virtual std::vector<FuncClosurePtr> GetFuncGraphs();
  140. protected:
  141. AnfNodePtr anf_node_ = nullptr;
  142. ValueManagerPtr manager_ = nullptr;
  143. };
  144. ValueGetterPtr ValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  145. return nullptr;
  146. }
  147. std::vector<FuncClosurePtr> ValueGetter::GetFuncGraphs() { return {}; }
  148. class MultipleValueGetter : public ValueGetter {
  149. public:
  150. explicit MultipleValueGetter(const ValueManagerPtr &manager) : ValueGetter(nullptr, manager) {}
  151. ~MultipleValueGetter() = default;
  152. MultipleValueGetter(const std::vector<ValueGetterPtr> &value_getters, const ValueManagerPtr &manager)
  153. : ValueGetter(nullptr, manager), value_getters_(value_getters) {}
  154. void AddValueGetter(const ValueGetterPtr &value_getter) {
  155. if (std::find(value_getters_.begin(), value_getters_.end(), value_getter) == value_getters_.end()) {
  156. value_getters_.push_back(value_getter);
  157. }
  158. }
  159. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  160. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  161. private:
  162. std::vector<ValueGetterPtr> value_getters_;
  163. };
  164. ValueGetterPtr MultipleValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  165. auto new_multiple_value_getter = std::make_shared<MultipleValueGetter>(manager_);
  166. for (auto &value_getter : value_getters_) {
  167. // Copy a new path.
  168. auto new_path = std::make_shared<HashSet<AnfNodePtr>>(*visit_path);
  169. new_multiple_value_getter->AddValueGetter(value_getter->Visit(index, new_path));
  170. }
  171. return new_multiple_value_getter;
  172. }
  173. std::vector<FuncClosurePtr> MultipleValueGetter::GetFuncGraphs() {
  174. std::vector<FuncClosurePtr> ret_func_closures;
  175. for (const auto &value_getter : value_getters_) {
  176. const auto &func_closures = value_getter->GetFuncGraphs();
  177. (void)std::copy_if(func_closures.begin(), func_closures.end(), std::back_inserter(ret_func_closures),
  178. [&ret_func_closures](const auto &closure) { return !closure->ExistInList(ret_func_closures); });
  179. }
  180. return ret_func_closures;
  181. }
  182. class MakeTupleValueGetter : public ValueGetter {
  183. public:
  184. MakeTupleValueGetter(const AnfNodePtr &make_tuple, const ValueManagerPtr &manager)
  185. : ValueGetter(make_tuple, manager) {}
  186. ~MakeTupleValueGetter() = default;
  187. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  188. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  189. };
  190. ValueGetterPtr MakeTupleValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  191. (void)visit_path->insert(anf_node_);
  192. const auto &make_tuple = anf_node_->cast<CNodePtr>();
  193. MS_EXCEPTION_IF_NULL(make_tuple);
  194. if (index == kGetAll) {
  195. auto multiple_value_getter = std::make_shared<MultipleValueGetter>(manager_);
  196. for (size_t i = 1; i < make_tuple->size(); i++) {
  197. auto new_path = std::make_shared<HashSet<AnfNodePtr>>(*visit_path);
  198. multiple_value_getter->AddValueGetter(manager_->GetValueGetter(make_tuple->input(i))->Visit(0, new_path));
  199. }
  200. return multiple_value_getter;
  201. }
  202. const auto &input_i = make_tuple->input(LongToSize(index + 1));
  203. auto input_i_getter = manager_->GetValueGetter(input_i);
  204. if (input_i_getter == nullptr) {
  205. MS_LOG(EXCEPTION) << "Make tuple: " << anf_node_->DebugString()
  206. << " get input value getter failed. Index: " << index << ", input_i: " << input_i->DebugString();
  207. }
  208. return input_i_getter;
  209. }
  210. std::vector<FuncClosurePtr> MakeTupleValueGetter::GetFuncGraphs() {
  211. MS_LOG(EXCEPTION) << "MakeTupleValueGetter has no func graphs, anf_node_:" << anf_node_->DebugString();
  212. }
  213. class TupleGetItemValueGetter : public ValueGetter {
  214. public:
  215. TupleGetItemValueGetter(const AnfNodePtr &tuple_getitem, const ValueManagerPtr &manager)
  216. : ValueGetter(tuple_getitem, manager) {}
  217. ~TupleGetItemValueGetter() = default;
  218. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  219. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  220. };
  221. ValueGetterPtr TupleGetItemValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  222. (void)visit_path->insert(anf_node_);
  223. auto tuple_getitem = anf_node_->cast<CNodePtr>();
  224. MS_EXCEPTION_IF_NULL(tuple_getitem);
  225. // Get cur index
  226. auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
  227. auto value_node = output_index_value_node->cast<ValueNodePtr>();
  228. MS_EXCEPTION_IF_NULL(value_node);
  229. auto cur_index = LongToSize(GetValue<int64_t>(value_node->value()));
  230. // Get real input value getter
  231. auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
  232. const auto &real_input_value_getter = manager_->GetValueGetter(real_input).get();
  233. MS_EXCEPTION_IF_NULL(real_input_value_getter);
  234. return real_input_value_getter->Visit(cur_index, visit_path)->Visit(index, visit_path);
  235. }
  236. std::vector<FuncClosurePtr> TupleGetItemValueGetter::GetFuncGraphs() {
  237. MS_LOG(EXCEPTION) << "TupleGetItemValueGetter has no func graphs, anf_node_:" << anf_node_->DebugString();
  238. }
  239. class DependValueGetter : public ValueGetter {
  240. public:
  241. DependValueGetter(const AnfNodePtr &depend, const ValueManagerPtr &manager) : ValueGetter(depend, manager) {}
  242. ~DependValueGetter() = default;
  243. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  244. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  245. };
  246. ValueGetterPtr DependValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  247. (void)visit_path->insert(anf_node_);
  248. auto depend = anf_node_->cast<CNodePtr>();
  249. MS_EXCEPTION_IF_NULL(depend);
  250. auto real_input = depend->input(kRealInputIndexInDepend);
  251. return manager_->GetValueGetter(real_input)->Visit(index, visit_path);
  252. }
  253. std::vector<FuncClosurePtr> DependValueGetter::GetFuncGraphs() {
  254. MS_LOG(EXCEPTION) << "DependValueGetter has no func graphs, anf_node_: " << anf_node_->DebugString();
  255. }
  256. class PartialValueGetter : public ValueGetter, public std::enable_shared_from_this<PartialValueGetter> {
  257. public:
  258. PartialValueGetter(const AnfNodePtr &partial, const ValueManagerPtr &manager) : ValueGetter(partial, manager) {}
  259. ~PartialValueGetter() = default;
  260. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  261. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  262. private:
  263. ValueGetterPtr real_value_getter_ = nullptr;
  264. };
  265. ValueGetterPtr PartialValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  266. (void)visit_path->insert(anf_node_);
  267. auto partial = anf_node_->cast<CNodePtr>();
  268. MS_EXCEPTION_IF_NULL(partial);
  269. auto constexpr function_input_index = 1;
  270. auto real_input = partial->input(function_input_index);
  271. real_value_getter_ = manager_->GetValueGetter(real_input)->Visit(index, visit_path);
  272. return shared_from_this();
  273. }
  274. std::vector<FuncClosurePtr> PartialValueGetter::GetFuncGraphs() {
  275. if (real_value_getter_ == nullptr) {
  276. MS_LOG(EXCEPTION) << "Real value getter is null, please visit before get func graphs.node:"
  277. << anf_node_->DebugString();
  278. }
  279. auto input_closures = real_value_getter_->GetFuncGraphs();
  280. constexpr auto arg_start_idx = 2;
  281. auto partial = anf_node_->cast<CNodePtr>();
  282. std::vector<FuncClosurePtr> closures;
  283. for (const auto &closure : input_closures) {
  284. auto arg_indexes = closure->arg_indexes_;
  285. auto arg_users = closure->arg_users_;
  286. for (size_t i = arg_start_idx; i < partial->inputs().size(); i++) {
  287. arg_indexes.emplace_back(i);
  288. arg_users.emplace_back(partial);
  289. }
  290. closures.emplace_back(std::make_shared<FuncClosure>(closure->func_graph_, arg_indexes, arg_users));
  291. }
  292. return closures;
  293. }
  294. class CallerValueGetter : public ValueGetter {
  295. public:
  296. CallerValueGetter(const AnfNodePtr &call, const ValueManagerPtr &manager) : ValueGetter(call, manager) {}
  297. ~CallerValueGetter() = default;
  298. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  299. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  300. };
  301. ValueGetterPtr CallerValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  302. // Get the func_graph called.
  303. auto call = anf_node_->cast<CNodePtr>();
  304. MS_EXCEPTION_IF_NULL(call);
  305. auto input0 = call->input(0);
  306. std::vector<FuncClosurePtr> called_func_graphs;
  307. const auto &input0_getter = manager_->GetValueGetter(input0).get();
  308. // If get func graph from caller output, incorporate call exist.
  309. manager_->has_incorporate_call_ = true;
  310. if (input0_getter != nullptr) {
  311. auto new_path = std::make_shared<HashSet<AnfNodePtr>>();
  312. auto value_getter = input0_getter->Visit(0, new_path);
  313. called_func_graphs = value_getter->GetFuncGraphs();
  314. }
  315. if (called_func_graphs.empty()) {
  316. MS_LOG(EXCEPTION) << "Call node get value failed,node: " << anf_node_->DebugString();
  317. }
  318. // Get the call return value getters
  319. std::vector<ValueGetterPtr> output_value_getters;
  320. for (const auto &fg_closure : called_func_graphs) {
  321. auto new_path = std::make_shared<HashSet<AnfNodePtr>>(*visit_path);
  322. const auto &output_value_getter =
  323. manager_->GetValueGetter(fg_closure->func_graph_->output())->Visit(index, new_path);
  324. output_value_getters.push_back(output_value_getter);
  325. }
  326. if (output_value_getters.size() == 1) {
  327. return output_value_getters.back();
  328. } else {
  329. auto new_multiple_value_getter = std::make_shared<MultipleValueGetter>(manager_);
  330. for (const auto &output_value_getter : output_value_getters) {
  331. new_multiple_value_getter->AddValueGetter(output_value_getter);
  332. }
  333. return new_multiple_value_getter;
  334. }
  335. }
  336. std::vector<FuncClosurePtr> CallerValueGetter::GetFuncGraphs() {
  337. MS_LOG(EXCEPTION) << "Caller node can't call the func get func graphs, call node: " << anf_node_->DebugString();
  338. }
  339. class SwitchValueGetter : public ValueGetter {
  340. public:
  341. SwitchValueGetter(const AnfNodePtr &switch_node, const ValueManagerPtr &manager)
  342. : ValueGetter(switch_node, manager) {}
  343. ~SwitchValueGetter() = default;
  344. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  345. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  346. };
  347. ValueGetterPtr SwitchValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  348. auto constexpr true_branch_index = 2;
  349. auto constexpr false_branch_index = 3;
  350. auto switch_node = anf_node_->cast<CNodePtr>();
  351. MS_EXCEPTION_IF_NULL(switch_node);
  352. auto true_branch_node = switch_node->input(true_branch_index);
  353. auto false_branch_node = switch_node->input(false_branch_index);
  354. auto true_value_getter = manager_->GetValueGetter(true_branch_node)->Visit(index, visit_path);
  355. auto new_path = std::make_shared<HashSet<AnfNodePtr>>(*visit_path);
  356. auto false_value_getter = manager_->GetValueGetter(false_branch_node)->Visit(index, new_path);
  357. auto multiple_value_getter = std::make_shared<MultipleValueGetter>(manager_);
  358. multiple_value_getter->AddValueGetter(true_value_getter);
  359. multiple_value_getter->AddValueGetter(false_value_getter);
  360. return multiple_value_getter;
  361. }
  362. std::vector<FuncClosurePtr> SwitchValueGetter::GetFuncGraphs() {
  363. MS_LOG(EXCEPTION) << "Switch node can't call the func get func graphs, switch: " << anf_node_->DebugString();
  364. }
  365. class SwitchLayerValueGetter : public ValueGetter {
  366. public:
  367. SwitchLayerValueGetter(const AnfNodePtr &switch_layer, const ValueManagerPtr &manager)
  368. : ValueGetter(switch_layer, manager) {}
  369. ~SwitchLayerValueGetter() = default;
  370. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  371. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  372. };
  373. ValueGetterPtr SwitchLayerValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  374. constexpr auto funcs_make_tuple_index = 2;
  375. const auto &switch_layer = anf_node_->cast<CNodePtr>();
  376. MS_EXCEPTION_IF_NULL(switch_layer);
  377. auto switch_layer_input1 = switch_layer->input(funcs_make_tuple_index);
  378. return manager_->GetValueGetter(switch_layer_input1)->Visit(kGetAll, visit_path);
  379. }
  380. std::vector<FuncClosurePtr> SwitchLayerValueGetter::GetFuncGraphs() {
  381. MS_LOG(EXCEPTION) << "SwitchLayer node can't call the func get func graphs, switch layer: "
  382. << anf_node_->DebugString();
  383. }
  384. // ParameterValueGetter should be analysis after others caller
  385. class ParameterValueGetter : public ValueGetter {
  386. public:
  387. ParameterValueGetter(const AnfNodePtr &parameter, const ValueManagerPtr &manager) : ValueGetter(parameter, manager) {}
  388. ~ParameterValueGetter() = default;
  389. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) override;
  390. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  391. };
  392. ValueGetterPtr ParameterValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &visit_path) {
  393. // If get func graph from parameter, incorporate call exist.
  394. manager_->has_incorporate_call_ = true;
  395. // If anf_node_ in visit path, it remarks there is a recursive call
  396. if (visit_path->find(anf_node_) != visit_path->end()) {
  397. MS_LOG(INFO) << "Node: " << anf_node_->DebugString() << " has been visited.";
  398. return std::make_shared<MultipleValueGetter>(manager_);
  399. }
  400. // Add parameter to path
  401. (void)visit_path->insert(anf_node_);
  402. // Find node users
  403. auto param_func = anf_node_->func_graph();
  404. MS_EXCEPTION_IF_NULL(param_func);
  405. auto multiple_value_getter = std::make_shared<MultipleValueGetter>(manager_);
  406. const auto &it = manager_->func_graph_real_users_.find(param_func);
  407. if (it != manager_->func_graph_real_users_.end()) {
  408. const auto &calls = it->second;
  409. for (const auto &call : calls) {
  410. const auto &args = manager_->GetArg(anf_node_, call);
  411. for (const auto &arg : args) {
  412. auto new_path = std::make_shared<HashSet<AnfNodePtr>>(*visit_path);
  413. auto arg_value_getter = manager_->GetValueGetter(arg)->Visit(index, new_path);
  414. multiple_value_getter->AddValueGetter(arg_value_getter);
  415. }
  416. }
  417. }
  418. return multiple_value_getter;
  419. }
  420. std::vector<FuncClosurePtr> ParameterValueGetter::GetFuncGraphs() {
  421. // If parameter has not find it's arg value getter, we return a empty func graphs.
  422. MS_LOG(INFO) << "Undetermined parameter function,node: " << anf_node_->DebugString();
  423. return {};
  424. }
  425. class DirectValueGetter : public ValueGetter, public std::enable_shared_from_this<DirectValueGetter> {
  426. public:
  427. DirectValueGetter(const AnfNodePtr &value_node, const ValueManagerPtr &manager) : ValueGetter(value_node, manager) {}
  428. ~DirectValueGetter() = default;
  429. ValueGetterPtr Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &) override;
  430. std::vector<FuncClosurePtr> GetFuncGraphs() override;
  431. private:
  432. std::vector<FuncClosurePtr> func_graphs_;
  433. };
  434. ValueGetterPtr DirectValueGetter::Visit(int64_t index, const std::shared_ptr<HashSet<AnfNodePtr>> &) {
  435. MS_LOG(DEBUG) << "Visit direct value getter: " << anf_node_->DebugString();
  436. return shared_from_this();
  437. }
  438. std::vector<FuncClosurePtr> DirectValueGetter::GetFuncGraphs() {
  439. if (!IsValueNode<FuncGraph>(anf_node_)) {
  440. MS_LOG(EXCEPTION) << "Expect a func graph value node, but got an illegal value node:" << anf_node_->DebugString();
  441. }
  442. if (func_graphs_.empty()) {
  443. func_graphs_.emplace_back(std::make_shared<FuncClosure>(GetValueNode<FuncGraphPtr>(anf_node_),
  444. std::vector<size_t>(), std::vector<CNodePtr>()));
  445. }
  446. return func_graphs_;
  447. }
  448. bool IsFuncGraphCallNode(const AnfNodePtr &node) {
  449. if (!node->isa<CNode>()) {
  450. return false;
  451. }
  452. auto input0 = node->cast<CNodePtr>()->input(0);
  453. if (IsValueNode<Primitive>(input0)) {
  454. return false;
  455. }
  456. return true;
  457. }
  458. ValueGetterPtr CreateValueGetter(const AnfNodePtr &node, const ValueManagerPtr &manager) {
  459. if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
  460. return std::make_shared<MakeTupleValueGetter>(node, manager);
  461. }
  462. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  463. return std::make_shared<TupleGetItemValueGetter>(node, manager);
  464. }
  465. if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
  466. return std::make_shared<DependValueGetter>(node, manager);
  467. }
  468. if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
  469. return std::make_shared<PartialValueGetter>(node, manager);
  470. }
  471. if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
  472. return std::make_shared<SwitchValueGetter>(node, manager);
  473. }
  474. if (IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
  475. return std::make_shared<SwitchLayerValueGetter>(node, manager);
  476. }
  477. if (IsFuncGraphCallNode(node)) {
  478. return std::make_shared<CallerValueGetter>(node, manager);
  479. }
  480. if (node->isa<Parameter>()) {
  481. return std::make_shared<ParameterValueGetter>(node, manager);
  482. }
  483. if (node->isa<ValueNode>()) {
  484. return std::make_shared<DirectValueGetter>(node, manager);
  485. }
  486. // Others are prim cnode.
  487. MS_LOG(EXCEPTION) << "Unexpected value getter node: " << node->DebugString();
  488. }
  489. std::vector<AnfNodePtr> GetAllCallNodes(const FuncGraphPtr &func_graph) {
  490. std::vector<AnfNodePtr> calls;
  491. const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  492. std::copy_if(all_nodes.begin(), all_nodes.end(), std::back_inserter(calls),
  493. [](const AnfNodePtr &node) { return IsFuncGraphCallNode(node); });
  494. return calls;
  495. }
  496. bool FuncClosure::ExistInList(const std::vector<std::shared_ptr<FuncClosure>> &list) const {
  497. return std::any_of(list.begin(), list.end(), [this](const auto &list_item) { return *list_item == *this; });
  498. }
  499. std::vector<AnfNodePtr> FuncClosure::GetArgs() const {
  500. std::vector<AnfNodePtr> args;
  501. for (size_t i = 0; i < arg_indexes_.size(); i++) {
  502. args.emplace_back(arg_users_[i]->input(arg_indexes_[i]));
  503. }
  504. return args;
  505. }
  506. std::string FuncClosure::ToString() const {
  507. std::ostringstream buffer;
  508. buffer << "\nfg:," << func_graph_->ToString();
  509. for (size_t i = 0; i < arg_users_.size(); i++) {
  510. buffer << "\narg[" << i << "]:" << arg_users_[i]->input(arg_indexes_[i])->ToString();
  511. }
  512. buffer << "\n===================================================";
  513. return buffer.str();
  514. }
  515. FuncGraphAnalyzer::FuncGraphAnalyzer(const FuncGraphPtr &func_graph) {
  516. root_graph_ = func_graph;
  517. value_manager_ = std::make_shared<ValueManager>();
  518. }
  519. void FuncGraphAnalyzer::Run() {
  520. MS_LOG(INFO) << "Start.";
  521. const auto &calls = GetAllCallNodes(root_graph_);
  522. size_t cycle = 0;
  523. bool change = true;
  524. while (change) {
  525. change = false;
  526. MS_LOG(INFO) << "Func graph call analysis cycle:" << cycle;
  527. for (const auto &call : calls) {
  528. MS_LOG(INFO) << "Start analysis call node: " << call->DebugString();
  529. auto input0 = call->cast<CNodePtr>()->input(0);
  530. auto value_getter = value_manager_->GetValueGetter(input0);
  531. value_getter = value_getter->Visit(0, std::make_shared<HashSet<AnfNodePtr>>());
  532. change = value_manager_->UpdateGraphRelations(value_getter->GetFuncGraphs(), call) || change;
  533. }
  534. ++cycle;
  535. }
  536. DumpFuncGraphRealUsers();
  537. MS_LOG(INFO) << "End.";
  538. }
  539. std::vector<CNodePtr> FuncGraphAnalyzer::GetFuncGraphCallers(const FuncGraphPtr &func_graph) const {
  540. MS_EXCEPTION_IF_NULL(func_graph);
  541. auto it = value_manager_->func_graph_real_users_.find(func_graph);
  542. if (it == value_manager_->func_graph_real_users_.end()) {
  543. MS_LOG(INFO) << "Find func graph:" << func_graph->ToString() << " failed.";
  544. return {};
  545. }
  546. return it->second;
  547. }
  548. std::vector<FuncGraphPtr> FuncGraphAnalyzer::GetCallerFuncGraphs(const AnfNodePtr &node) const {
  549. const auto &closures = GetCallClosures(node);
  550. std::vector<FuncGraphPtr> func_graphs;
  551. for (const auto &closure : closures) {
  552. if (std::find(func_graphs.begin(), func_graphs.end(), closure->func_graph_) == func_graphs.end()) {
  553. func_graphs.emplace_back(closure->func_graph_);
  554. }
  555. }
  556. return func_graphs;
  557. }
  558. const std::vector<FuncClosurePtr> &FuncGraphAnalyzer::GetCallClosures(const AnfNodePtr &call) const {
  559. MS_EXCEPTION_IF_NULL(call);
  560. auto it = value_manager_->caller_closures_.find(call);
  561. if (it != value_manager_->caller_closures_.end()) {
  562. return it->second;
  563. }
  564. MS_LOG(EXCEPTION) << "Find closure of call: " << call->DebugString() << " failed.";
  565. }
  566. std::vector<AnfNodePtr> FuncGraphAnalyzer::GetArg(const AnfNodePtr &param, const AnfNodePtr &call) const {
  567. return value_manager_->GetArg(param, call);
  568. }
  569. void FuncGraphAnalyzer::DumpFuncGraphRealUsers() const {
  570. const auto &func_graph_callers = value_manager_->func_graph_real_users_;
  571. MS_LOG(INFO) << "Func graph size:" << func_graph_callers.size();
  572. size_t fg_index = 0;
  573. std::ostringstream buffer;
  574. buffer << "\n";
  575. for (auto &it : func_graph_callers) {
  576. const auto fg = it.first;
  577. const auto callers = it.second;
  578. buffer << "FuncGraph[" << fg_index++ << "]:" << fg->ToString() << "\n";
  579. for (size_t i = 0; i < callers.size(); i++) {
  580. buffer << "---->Caller[" << i << "]:" << callers[i]->DebugString() << "\n";
  581. }
  582. }
  583. MS_LOG(INFO) << buffer.str();
  584. }
  585. bool FuncGraphAnalyzer::ExistClosure() const {
  586. for (const auto &[call, closures] : value_manager_->caller_closures_) {
  587. for (const auto &closure : closures) {
  588. if (closure->arg_indexes_.empty()) {
  589. continue;
  590. }
  591. const auto &last_arg_user = closure->arg_users_.back();
  592. // Partial's arg and call are in same graph, this is not closure.
  593. if (last_arg_user->func_graph() != call->func_graph()) {
  594. return true;
  595. }
  596. }
  597. }
  598. return false;
  599. }
  600. bool FuncGraphAnalyzer::HasIncorporateCall() const { return value_manager_->has_incorporate_call_; }
  601. } // namespace mindspore
  602. // namespace mindspore