You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph_cloner.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/func_graph_cloner.h"
  17. #include <algorithm>
  18. #include "ir/manager.h"
  19. #include "ir/param_value_py.h"
  20. #include "operator/ops.h"
  21. #include "utils/log_adapter.h"
  22. #include "utils/profile.h"
  23. #include "utils/context/ms_context.h"
  24. // namespace to support intermediate representation definition
  25. namespace mindspore {
  26. Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
  27. bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
  28. : clone_all_valuenodes_(clone_all_valuenodes),
  29. clone_all_child_graphs_(clone_all_child_graphs),
  30. clone_all_used_graphs_(clone_all_used_graphs),
  31. relation_(relation),
  32. target_relation_(target_relation == nullptr ? relation : target_relation) {
  33. for (auto &func_graph : func_graphs) {
  34. AddClone(func_graph);
  35. }
  36. scope_ = kDefaultScope;
  37. type_ = kBasic;
  38. }
  39. void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  40. const AnfNodePtrList &params, CloneType type) {
  41. if (func_graph != nullptr) {
  42. todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
  43. type_ = type;
  44. }
  45. }
  46. void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  47. MS_EXCEPTION_IF_NULL(node);
  48. if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
  49. return;
  50. }
  51. if (node->isa<Parameter>()) {
  52. CloneParameter(node, target);
  53. } else if (node->isa<CNode>()) {
  54. CloneCNode(node, target);
  55. }
  56. }
  57. void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
  58. MS_EXCEPTION_IF_NULL(node);
  59. MS_EXCEPTION_IF_NULL(target);
  60. TraceManager::DebugTrace(node->debug_info(), relation_);
  61. auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
  62. auto old_param = node->cast<ParameterPtr>();
  63. new_param->set_abstract(old_param->abstract());
  64. new_param->set_name(old_param->name());
  65. if (old_param->has_default()) {
  66. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
  67. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  68. new_param->set_default_param(param_value_new);
  69. }
  70. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  71. new_param->set_scope(scope);
  72. repl_node_[node] = new_param;
  73. TraceManager::EndTrace();
  74. }
  75. void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  76. MS_EXCEPTION_IF_NULL(node);
  77. MS_EXCEPTION_IF_NULL(target);
  78. TraceManager::DebugTrace(node->debug_info(), relation_);
  79. CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
  80. auto old_node = node->cast<CNodePtr>();
  81. new_node->set_abstract(old_node->abstract());
  82. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  83. new_node->set_scope(scope);
  84. repl_node_[old_node] = new_node;
  85. nodes_.emplace_back(old_node, new_node);
  86. TraceManager::EndTrace();
  87. }
  88. void Cloner::CloneValueNode(const AnfNodePtr &node) {
  89. MS_EXCEPTION_IF_NULL(node);
  90. TraceManager::DebugTrace(node->debug_info(), relation_);
  91. ValueNodePtr new_const = NewValueNode(GetValueNode(node));
  92. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  93. new_const->set_scope(scope);
  94. new_const->set_abstract(node->abstract());
  95. repl_node_[node] = new_const;
  96. TraceManager::EndTrace();
  97. }
  98. void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  99. MS_EXCEPTION_IF_NULL(node);
  100. MS_EXCEPTION_IF_NULL(target);
  101. TraceManager::DebugTrace(node->debug_info(), relation_);
  102. ValueNodePtr new_const = NewValueNode(target);
  103. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  104. new_const->set_scope(scope);
  105. new_const->set_abstract(node->abstract());
  106. repl_node_[node] = new_const;
  107. TraceManager::EndTrace();
  108. }
  109. void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
  110. MS_EXCEPTION_IF_NULL(func_graph);
  111. MS_EXCEPTION_IF_NULL(manager_);
  112. if (!clone_all_valuenodes_) {
  113. return;
  114. }
  115. auto &value_nodes = func_graph->value_nodes();
  116. for (auto &value_node : value_nodes) {
  117. auto old_node = value_node.first;
  118. MS_EXCEPTION_IF_NULL(old_node);
  119. if (repl_node_.count(old_node) == 0) {
  120. CloneValueNode(old_node);
  121. }
  122. }
  123. }
  124. void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
  125. MS_EXCEPTION_IF_NULL(func_graph);
  126. MS_EXCEPTION_IF_NULL(manager_);
  127. if (!clone_all_child_graphs_) {
  128. return;
  129. }
  130. auto &scopes = manager_->scopes(func_graph);
  131. for (auto &graph : scopes) {
  132. if (graph != func_graph) {
  133. todo_.push_back({graph, nullptr, {}});
  134. }
  135. }
  136. }
  137. void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
  138. MS_EXCEPTION_IF_NULL(func_graph);
  139. MS_EXCEPTION_IF_NULL(manager_);
  140. if (!clone_all_used_graphs_) {
  141. return;
  142. }
  143. auto &used = func_graph->func_graphs_used();
  144. for (auto &fg : used) {
  145. todo_.push_back({fg.first, nullptr, {}});
  146. }
  147. }
  148. void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  149. MS_EXCEPTION_IF_NULL(func_graph);
  150. MS_EXCEPTION_IF_NULL(target_func_graph);
  151. for (auto &item : func_graph->parameter_default_value()) {
  152. auto nodes = DeepLinkedGraphSearch(item.second);
  153. for (auto &node : nodes) {
  154. MS_EXCEPTION_IF_NULL(node);
  155. if (node->isa<CNode>()) {
  156. CloneNode(node, target_func_graph);
  157. } else if (node->isa<ValueNode>()) {
  158. CloneValueNode(node);
  159. }
  160. }
  161. }
  162. }
  163. void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  164. MS_EXCEPTION_IF_NULL(func_graph);
  165. MS_EXCEPTION_IF_NULL(target_func_graph);
  166. MS_EXCEPTION_IF_NULL(manager_);
  167. auto return_node = repl_node_[func_graph->get_return()]->cast<CNodePtr>();
  168. if (return_node == nullptr) {
  169. MS_LOG(EXCEPTION) << "Can't find replicate node for return.";
  170. }
  171. target_func_graph->set_return(return_node);
  172. auto &cnodes = func_graph->func_graph_cnodes_index();
  173. for (auto &cnode : cnodes) {
  174. auto parent = cnode.first->first->cast<CNodePtr>();
  175. auto valuenode = parent->input(cnode.first->second);
  176. CloneValueNode(valuenode, target_func_graph);
  177. }
  178. }
  179. void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
  180. MS_EXCEPTION_IF_NULL(func_graph);
  181. auto &old_params = func_graph->parameters();
  182. if (old_params.size() != params.size()) {
  183. MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
  184. return;
  185. }
  186. for (size_t i = 0; i < old_params.size(); ++i) {
  187. repl_node_[old_params[i]] = params[i];
  188. }
  189. }
  190. void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
  191. MS_EXCEPTION_IF_NULL(func_graph);
  192. MS_EXCEPTION_IF_NULL(target_func_graph);
  193. TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
  194. *target_func_graph = std::make_shared<FuncGraph>();
  195. (*target_func_graph)->set_flags(func_graph->flags());
  196. (*target_func_graph)->set_transforms(func_graph->transforms());
  197. (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
  198. (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
  199. (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  200. (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
  201. (*target_func_graph)->set_is_generate(func_graph->is_generated());
  202. TraceManager::EndTrace();
  203. }
  204. void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  205. MS_EXCEPTION_IF_NULL(func_graph);
  206. MS_EXCEPTION_IF_NULL(target_func_graph);
  207. auto &params = func_graph->parameters();
  208. for (auto &param : params) {
  209. CloneParameter(param, target_func_graph, true);
  210. }
  211. repl_func_graph_[func_graph] = target_func_graph;
  212. }
  213. void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
  214. MS_EXCEPTION_IF_NULL(func_graph);
  215. auto &free_vars = manager_->free_variables_total();
  216. auto iter = free_vars.find(func_graph);
  217. if (iter == free_vars.end()) {
  218. return;
  219. }
  220. for (auto &fv_map : iter->second) {
  221. auto &free_var = fv_map.first;
  222. if (utils::isa<AnfNodePtr>(free_var)) {
  223. repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
  224. }
  225. }
  226. }
  227. void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
  228. param->set_abstract(node->abstract());
  229. if (node->isa<Parameter>()) {
  230. ParameterPtr old_param = dyn_cast<Parameter>(node);
  231. if (old_param->has_default()) {
  232. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
  233. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  234. param->set_default_param(param_value_new);
  235. }
  236. param->set_name(old_param->name());
  237. }
  238. }
  239. ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
  240. TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
  241. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  242. TraceManager::EndTrace();
  243. CloneParameter(param, node);
  244. if (is_add) {
  245. func_graph->add_parameter(param);
  246. }
  247. repl_node_[param] = node;
  248. repl_map_node_[func_graph][node] = param;
  249. return param;
  250. }
  251. void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
  252. AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
  253. AnfNodePtrList parameters;
  254. std::unordered_set<AnfNodePtr> old_params;
  255. for (auto &param : func_graph->parameters()) {
  256. auto iter = repl_node_.find(param);
  257. if (iter != repl_node_.end()) {
  258. (void)old_params.insert(iter->second);
  259. parameters.push_back(param);
  260. } else {
  261. parameters.push_back(AddParameter(func_graph, param, false));
  262. (void)old_params.insert(param);
  263. }
  264. }
  265. AnfNodePtr new_param = nullptr;
  266. for (auto &param : params) {
  267. auto old_param = repl_node_[param];
  268. if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
  269. repl_node_[old_param] = old_param;
  270. repl_map_node_[func_graph][old_param] = old_param;
  271. input_params->push_back(old_param);
  272. continue;
  273. }
  274. if (old_params.find(old_param) != old_params.end()) {
  275. new_param = repl_map_node_[func_graph][old_param];
  276. input_params->push_back(new_param);
  277. continue;
  278. }
  279. new_param = AddParameter(func_graph, old_param, false);
  280. parameters.push_back(new_param);
  281. lift_params->push_back(new_param);
  282. input_params->push_back(new_param);
  283. }
  284. func_graph->set_parameters(parameters);
  285. }
  286. void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  287. const AnfNodePtrList &params) {
  288. AnfNodePtr node = nullptr;
  289. auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
  290. auto iter = repl_func_graph.find(func_graph);
  291. if (iter == repl_func_graph.end()) {
  292. node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
  293. repl_func_graph[func_graph] = node;
  294. } else {
  295. node = iter->second;
  296. }
  297. if (node == nullptr || !node->isa<CNode>()) {
  298. return;
  299. }
  300. auto cnode = node->cast<CNodePtr>();
  301. auto inputs = cnode->inputs();
  302. (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  303. cnode->set_inputs(inputs);
  304. OrderParameters(func_graph, inputs);
  305. }
  306. void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
  307. std::unordered_set<AnfNodePtr> old_params;
  308. for (auto &param : func_graph->parameters()) {
  309. (void)old_params.insert(repl_node_[param]);
  310. }
  311. std::unordered_set<AnfNodePtr> new_params;
  312. AnfNodePtrList parameters;
  313. // Ignore the 1st and 2nd param of inputs(such as. partial graph)
  314. for (size_t i = 2; i < inputs.size(); ++i) {
  315. auto input = inputs[i];
  316. auto param = repl_node_[input];
  317. if (old_params.find(param) != old_params.end()) {
  318. auto new_param = repl_map_node_[func_graph][param];
  319. parameters.push_back(new_param);
  320. (void)new_params.insert(new_param);
  321. }
  322. }
  323. for (auto &param : func_graph->parameters()) {
  324. if (new_params.find(param) == new_params.end()) {
  325. parameters.push_back(param);
  326. }
  327. }
  328. func_graph->set_parameters(parameters);
  329. }
  330. void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
  331. MS_EXCEPTION_IF_NULL(func_graph);
  332. for (auto &node : func_graph->nodes()) {
  333. if (node == nullptr) {
  334. continue;
  335. }
  336. // Only cnode needed to be handled
  337. if (!node->isa<CNode>()) {
  338. continue;
  339. }
  340. auto cnode = node->cast<CNodePtr>();
  341. auto &inputs = cnode->inputs();
  342. for (size_t i = 0; i < inputs.size(); i++) {
  343. auto &input = inputs[i];
  344. if (IsValueNode<FuncGraph>(input)) {
  345. auto graph = GetValueNode<FuncGraphPtr>(input);
  346. auto &repl_func_graph = repl_map_func_graph_[func_graph];
  347. if (repl_func_graph.find(graph) != repl_func_graph.end()) {
  348. transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
  349. }
  350. } else {
  351. auto &repl_node = repl_map_node_[func_graph];
  352. if (repl_node.find(input) != repl_node.end()) {
  353. transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
  354. }
  355. }
  356. }
  357. }
  358. }
  359. void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  360. const AnfNodePtrList &params) {
  361. AnfNodePtrList lift_params;
  362. AnfNodePtrList input_params;
  363. AddParameters(func_graph_user, params, &lift_params, &input_params);
  364. AddInputs(func_graph_user, func_graph, input_params);
  365. if (lift_params.empty()) {
  366. return;
  367. }
  368. for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
  369. LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
  370. }
  371. }
  372. void Cloner::Lift() {
  373. for (auto &func_graph_params : repl_func_graph_params_) {
  374. auto &func_graph = func_graph_params.first;
  375. auto &params = func_graph_params.second;
  376. for (auto &cnode : func_graph->func_graph_cnodes_index()) {
  377. LiftParameters(cnode.first->first->func_graph(), func_graph, params);
  378. }
  379. }
  380. }
  381. void Cloner::LiftParameters() {
  382. MS_EXCEPTION_IF_NULL(manager_);
  383. transaction_ = manager_->Transact();
  384. const FuncGraphSet &func_graphs = manager_->func_graphs();
  385. for (auto &func_graph : func_graphs) {
  386. GenParameters(func_graph);
  387. }
  388. Lift();
  389. for (auto &func_graph : func_graphs) {
  390. SetEdges(func_graph);
  391. }
  392. transaction_.Commit();
  393. }
  394. bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
  395. MS_EXCEPTION_IF_NULL(func_graph);
  396. // Make sure only inline once
  397. if (status_.count(func_graph) != 0) {
  398. if (is_inline == status_[func_graph]) {
  399. return false;
  400. }
  401. if (clone_all_used_graphs_) {
  402. MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
  403. return false;
  404. }
  405. }
  406. return true;
  407. }
  408. void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  409. MS_EXCEPTION_IF_NULL(func_graph);
  410. MS_EXCEPTION_IF_NULL(target_func_graph);
  411. MS_EXCEPTION_IF_NULL(manager_);
  412. const AnfNodeSet &nodes = func_graph->nodes();
  413. for (auto &node : nodes) {
  414. CloneNode(node, target_func_graph);
  415. }
  416. }
  417. void Cloner::Run() {
  418. if (todo_.empty()) {
  419. return;
  420. }
  421. if (type_ < kLifting) {
  422. // Basic and Inline Clone
  423. FuncGraphPtrList func_graphs;
  424. (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
  425. [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
  426. manager_ = Manage(func_graphs, false);
  427. CloneNodes();
  428. LinkEdges();
  429. SetDefaults();
  430. } else {
  431. // Lifting Clone
  432. CloneInfo item = todo_.back();
  433. manager_ = Manage(item.origin);
  434. LiftParameters();
  435. }
  436. }
  437. void Cloner::CloneNodes() {
  438. while (!todo_.empty()) {
  439. CloneInfo item = todo_.back();
  440. todo_.pop_back();
  441. bool is_inline = (item.target != nullptr);
  442. FuncGraphPtr func_graph = item.origin;
  443. FuncGraphPtr target_func_graph = item.target;
  444. (void)graph_set_.insert(func_graph);
  445. if (!CheckStatus(func_graph, is_inline)) {
  446. continue;
  447. }
  448. if (is_inline) {
  449. InlineCloneParameters(func_graph, item.params);
  450. CloneAllNodes(func_graph, target_func_graph);
  451. } else {
  452. SetFuncGraphInfo(func_graph, &target_func_graph);
  453. CloneParameters(func_graph, target_func_graph);
  454. CloneAllNodes(func_graph, target_func_graph);
  455. CloneFuncGraphValueNodes(func_graph, target_func_graph);
  456. CloneFuncGraphDefaultValues(func_graph, target_func_graph);
  457. }
  458. CloneValueNodes(func_graph);
  459. AddChildGraphs(func_graph);
  460. AddTotalGraphs(func_graph);
  461. status_[func_graph] = is_inline;
  462. }
  463. }
  464. void Cloner::LinkEdges() {
  465. for (auto &node_pair : nodes_) {
  466. CNodePtr old_node = node_pair.first;
  467. CNodePtr new_node = node_pair.second;
  468. MS_EXCEPTION_IF_NULL(old_node);
  469. MS_EXCEPTION_IF_NULL(new_node);
  470. for (auto &input : old_node->inputs()) {
  471. auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
  472. new_node->add_input(new_input);
  473. }
  474. }
  475. }
  476. // For the graphs cloned, update its default value map to the cloned nodes
  477. void Cloner::SetDefaults() {
  478. for (auto &item : graph_set_) {
  479. MS_EXCEPTION_IF_NULL(item);
  480. if (repl_func_graph_.count(item) != 0) {
  481. for (auto &param_def : item->parameter_default_value()) {
  482. MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
  483. if (repl_node_.count(param_def.second) != 0) {
  484. repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
  485. } else {
  486. repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second);
  487. }
  488. }
  489. }
  490. }
  491. }
  492. AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  493. MS_EXCEPTION_IF_NULL(root);
  494. if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
  495. MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  496. }
  497. CloneNode(root, repl_func_graph_[root->func_graph()]);
  498. auto iter = repl_node_.find(root);
  499. if (iter != repl_node_.end()) {
  500. return iter->second;
  501. }
  502. MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  503. }
  504. AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
  505. #ifdef ENABLE_PROFILE
  506. double time = GetTime();
  507. #endif
  508. Run();
  509. #ifdef ENABLE_PROFILE
  510. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time);
  511. #endif
  512. return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
  513. }
  514. FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
  515. #ifdef ENABLE_PROFILE
  516. double time = GetTime();
  517. #endif
  518. Run();
  519. #ifdef ENABLE_PROFILE
  520. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time);
  521. #endif
  522. return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
  523. }
  524. FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
  525. MS_EXCEPTION_IF_NULL(func_graph);
  526. Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
  527. return cloner[func_graph];
  528. }
  529. AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  530. const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
  531. MS_EXCEPTION_IF_NULL(func_graph);
  532. MS_EXCEPTION_IF_NULL(target_func_graph);
  533. Cloner cloner({}, false);
  534. if (scope != nullptr) {
  535. cloner.set_scope(scope);
  536. }
  537. cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
  538. return cloner[func_graph->output()];
  539. }
  540. FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
  541. MS_EXCEPTION_IF_NULL(func_graph);
  542. Cloner cloner({}, false);
  543. cloner.AddClone(func_graph, nullptr, {}, kLifting);
  544. return cloner[func_graph];
  545. }
  546. ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  547. MS_EXCEPTION_IF_NULL(func_graph);
  548. FuncGraphPtrList func_graphs = {func_graph};
  549. ClonerPtr cloner =
  550. std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
  551. #ifdef ENABLE_PROFILE
  552. double time = GetTime();
  553. #endif
  554. cloner->Run();
  555. #ifdef ENABLE_PROFILE
  556. MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time);
  557. #endif
  558. return cloner;
  559. }
  560. FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  561. MS_EXCEPTION_IF_NULL(func_graph);
  562. TraceManager::DebugTrace(func_graph->debug_info(), relation);
  563. auto new_func_graph = std::make_shared<FuncGraph>();
  564. TraceManager::EndTrace();
  565. auto &parameters = func_graph->parameters();
  566. (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
  567. MS_EXCEPTION_IF_NULL(param);
  568. TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
  569. (void)new_func_graph->add_parameter();
  570. TraceManager::EndTrace();
  571. });
  572. Cloner cloner = Cloner();
  573. cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  574. AnfNodePtr output = cloner[func_graph->output()];
  575. new_func_graph->set_output(output);
  576. new_func_graph->set_has_vararg(func_graph->has_vararg());
  577. new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  578. new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  579. new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  580. new_func_graph->set_is_generate(func_graph->is_generated());
  581. for (auto &item : func_graph->parameter_default_value()) {
  582. new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  583. }
  584. if (MsContext::GetInstance()->is_multi_graph_sink()) {
  585. if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
  586. new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  587. }
  588. }
  589. return new_func_graph;
  590. }
  591. } // namespace mindspore