You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph_cloner.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/func_graph_cloner.h"
  17. #include <algorithm>
  18. #include "ir/manager.h"
  19. #include "ir/param_value_py.h"
  20. #include "operator/ops.h"
  21. #include "utils/convert_utils_base.h"
  22. #include "utils/log_adapter.h"
  23. #include "utils/profile.h"
  24. #include "utils/context/ms_context.h"
  25. // namespace to support intermediate representation definition
  26. namespace mindspore {
  27. Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
  28. bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
  29. : clone_all_valuenodes_(clone_all_valuenodes),
  30. clone_all_child_graphs_(clone_all_child_graphs),
  31. clone_all_used_graphs_(clone_all_used_graphs),
  32. relation_(relation),
  33. target_relation_(target_relation == nullptr ? relation : target_relation) {
  34. for (auto &func_graph : func_graphs) {
  35. AddClone(func_graph);
  36. }
  37. scope_ = kDefaultScope;
  38. type_ = kBasic;
  39. }
  40. void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  41. const AnfNodePtrList &params, CloneType type) {
  42. if (func_graph != nullptr) {
  43. todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
  44. type_ = type;
  45. }
  46. }
  47. void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  48. MS_EXCEPTION_IF_NULL(node);
  49. if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
  50. return;
  51. }
  52. if (node->isa<Parameter>()) {
  53. CloneParameter(node, target);
  54. } else if (node->isa<CNode>()) {
  55. CloneCNode(node, target);
  56. }
  57. }
  58. void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
  59. MS_EXCEPTION_IF_NULL(node);
  60. MS_EXCEPTION_IF_NULL(target);
  61. TraceManager::DebugTrace(node->debug_info(), relation_);
  62. auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
  63. auto old_param = node->cast<ParameterPtr>();
  64. new_param->set_abstract(old_param->abstract());
  65. new_param->set_name(old_param->name());
  66. if (old_param->has_default()) {
  67. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
  68. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  69. new_param->set_default_param(param_value_new);
  70. }
  71. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  72. new_param->set_scope(scope);
  73. repl_node_[node] = new_param;
  74. TraceManager::EndTrace();
  75. }
  76. void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  77. MS_EXCEPTION_IF_NULL(node);
  78. MS_EXCEPTION_IF_NULL(target);
  79. TraceManager::DebugTrace(node->debug_info(), relation_);
  80. CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
  81. auto old_node = node->cast<CNodePtr>();
  82. new_node->set_abstract(old_node->abstract());
  83. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  84. new_node->set_scope(scope);
  85. new_node->set_kernel_info(old_node->kernel_info_ptr());
  86. repl_node_[old_node] = new_node;
  87. nodes_.emplace_back(old_node, new_node);
  88. TraceManager::EndTrace();
  89. }
  90. void Cloner::CloneValueNode(const AnfNodePtr &node) {
  91. MS_EXCEPTION_IF_NULL(node);
  92. TraceManager::DebugTrace(node->debug_info(), relation_);
  93. ValueNodePtr new_const = NewValueNode(GetValueNode(node));
  94. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  95. new_const->set_scope(scope);
  96. new_const->set_abstract(node->abstract());
  97. repl_node_[node] = new_const;
  98. TraceManager::EndTrace();
  99. }
  100. void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  101. MS_EXCEPTION_IF_NULL(node);
  102. MS_EXCEPTION_IF_NULL(target);
  103. TraceManager::DebugTrace(node->debug_info(), relation_);
  104. ValueNodePtr new_const = NewValueNode(target);
  105. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  106. new_const->set_scope(scope);
  107. new_const->set_abstract(node->abstract());
  108. repl_node_[node] = new_const;
  109. TraceManager::EndTrace();
  110. }
  111. void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
  112. MS_EXCEPTION_IF_NULL(func_graph);
  113. MS_EXCEPTION_IF_NULL(manager_);
  114. if (!clone_all_valuenodes_) {
  115. return;
  116. }
  117. auto &value_nodes = func_graph->value_nodes();
  118. for (auto &value_node : value_nodes) {
  119. auto old_node = value_node.first;
  120. MS_EXCEPTION_IF_NULL(old_node);
  121. if (repl_node_.count(old_node) == 0) {
  122. CloneValueNode(old_node);
  123. }
  124. }
  125. }
  126. void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
  127. MS_EXCEPTION_IF_NULL(func_graph);
  128. MS_EXCEPTION_IF_NULL(manager_);
  129. if (!clone_all_child_graphs_) {
  130. return;
  131. }
  132. auto &scopes = manager_->scopes(func_graph);
  133. for (auto &graph : scopes) {
  134. if (graph != func_graph) {
  135. todo_.push_back({graph, nullptr, {}});
  136. }
  137. }
  138. }
  139. void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
  140. MS_EXCEPTION_IF_NULL(func_graph);
  141. MS_EXCEPTION_IF_NULL(manager_);
  142. if (!clone_all_used_graphs_) {
  143. return;
  144. }
  145. auto &used = func_graph->func_graphs_used();
  146. for (auto &fg : used) {
  147. todo_.push_back({fg.first, nullptr, {}});
  148. }
  149. }
  150. void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  151. MS_EXCEPTION_IF_NULL(func_graph);
  152. MS_EXCEPTION_IF_NULL(target_func_graph);
  153. for (auto &item : func_graph->parameter_default_value()) {
  154. auto nodes = DeepLinkedGraphSearch(item.second);
  155. for (auto &node : nodes) {
  156. MS_EXCEPTION_IF_NULL(node);
  157. if (node->isa<CNode>()) {
  158. CloneNode(node, target_func_graph);
  159. } else if (node->isa<ValueNode>()) {
  160. CloneValueNode(node);
  161. }
  162. }
  163. }
  164. }
  165. void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  166. MS_EXCEPTION_IF_NULL(func_graph);
  167. MS_EXCEPTION_IF_NULL(target_func_graph);
  168. MS_EXCEPTION_IF_NULL(manager_);
  169. auto return_node = repl_node_[func_graph->get_return()]->cast<CNodePtr>();
  170. if (return_node == nullptr) {
  171. MS_LOG(EXCEPTION) << "Can't find replicate node for return.";
  172. }
  173. target_func_graph->set_return(return_node);
  174. auto &cnodes = func_graph->func_graph_cnodes_index();
  175. for (auto &cnode : cnodes) {
  176. auto parent = cnode.first->first->cast<CNodePtr>();
  177. auto valuenode = parent->input(cnode.first->second);
  178. CloneValueNode(valuenode, target_func_graph);
  179. }
  180. }
  181. void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
  182. MS_EXCEPTION_IF_NULL(func_graph);
  183. auto &old_params = func_graph->parameters();
  184. if (old_params.size() != params.size()) {
  185. MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
  186. return;
  187. }
  188. for (size_t i = 0; i < old_params.size(); ++i) {
  189. repl_node_[old_params[i]] = params[i];
  190. }
  191. }
  192. void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
  193. MS_EXCEPTION_IF_NULL(func_graph);
  194. MS_EXCEPTION_IF_NULL(target_func_graph);
  195. TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
  196. *target_func_graph = std::make_shared<FuncGraph>();
  197. (*target_func_graph)->set_attrs(func_graph->attrs());
  198. (*target_func_graph)->set_transforms(func_graph->transforms());
  199. (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
  200. (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
  201. (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  202. (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
  203. (*target_func_graph)->set_is_generate(func_graph->is_generated());
  204. TraceManager::EndTrace();
  205. }
  206. void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  207. MS_EXCEPTION_IF_NULL(func_graph);
  208. MS_EXCEPTION_IF_NULL(target_func_graph);
  209. auto &params = func_graph->parameters();
  210. for (auto &param : params) {
  211. CloneParameter(param, target_func_graph, true);
  212. }
  213. repl_func_graph_[func_graph] = target_func_graph;
  214. }
  215. void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
  216. MS_EXCEPTION_IF_NULL(func_graph);
  217. auto &free_vars = manager_->free_variables_total();
  218. auto iter = free_vars.find(func_graph);
  219. if (iter == free_vars.end()) {
  220. return;
  221. }
  222. for (auto &fv_map : iter->second) {
  223. auto &free_var = fv_map.first;
  224. if (utils::isa<AnfNodePtr>(free_var)) {
  225. repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
  226. }
  227. }
  228. }
  229. void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
  230. param->set_abstract(node->abstract());
  231. if (node->isa<Parameter>()) {
  232. ParameterPtr old_param = dyn_cast<Parameter>(node);
  233. if (old_param->has_default()) {
  234. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
  235. auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
  236. param->set_default_param(param_value_new);
  237. }
  238. param->set_name(old_param->name());
  239. }
  240. }
  241. ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
  242. TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
  243. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  244. TraceManager::EndTrace();
  245. CloneParameter(param, node);
  246. if (is_add) {
  247. func_graph->add_parameter(param);
  248. }
  249. repl_node_[param] = node;
  250. repl_map_node_[func_graph][node] = param;
  251. return param;
  252. }
  253. void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
  254. AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
  255. AnfNodePtrList parameters;
  256. std::unordered_set<AnfNodePtr> old_params;
  257. for (auto &param : func_graph->parameters()) {
  258. auto iter = repl_node_.find(param);
  259. if (iter != repl_node_.end()) {
  260. (void)old_params.insert(iter->second);
  261. parameters.push_back(param);
  262. } else {
  263. parameters.push_back(AddParameter(func_graph, param, false));
  264. (void)old_params.insert(param);
  265. }
  266. }
  267. AnfNodePtr new_param = nullptr;
  268. for (auto &param : params) {
  269. auto old_param = repl_node_[param];
  270. if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
  271. repl_node_[old_param] = old_param;
  272. repl_map_node_[func_graph][old_param] = old_param;
  273. input_params->push_back(old_param);
  274. continue;
  275. }
  276. if (old_params.find(old_param) != old_params.end()) {
  277. new_param = repl_map_node_[func_graph][old_param];
  278. input_params->push_back(new_param);
  279. continue;
  280. }
  281. new_param = AddParameter(func_graph, old_param, false);
  282. parameters.push_back(new_param);
  283. lift_params->push_back(new_param);
  284. input_params->push_back(new_param);
  285. }
  286. func_graph->set_parameters(parameters);
  287. }
  288. void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  289. const AnfNodePtrList &params) {
  290. AnfNodePtr node = nullptr;
  291. auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
  292. auto iter = repl_func_graph.find(func_graph);
  293. if (iter == repl_func_graph.end()) {
  294. node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
  295. repl_func_graph[func_graph] = node;
  296. } else {
  297. node = iter->second;
  298. }
  299. if (node == nullptr || !node->isa<CNode>()) {
  300. return;
  301. }
  302. auto cnode = node->cast<CNodePtr>();
  303. auto inputs = cnode->inputs();
  304. (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  305. cnode->set_inputs(inputs);
  306. OrderParameters(func_graph, inputs);
  307. }
  308. void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
  309. std::unordered_set<AnfNodePtr> old_params;
  310. for (auto &param : func_graph->parameters()) {
  311. (void)old_params.insert(repl_node_[param]);
  312. }
  313. std::unordered_set<AnfNodePtr> new_params;
  314. AnfNodePtrList parameters;
  315. // Ignore the 1st and 2nd param of inputs(such as. partial graph)
  316. for (size_t i = 2; i < inputs.size(); ++i) {
  317. auto input = inputs[i];
  318. auto param = repl_node_[input];
  319. if (old_params.find(param) != old_params.end()) {
  320. auto new_param = repl_map_node_[func_graph][param];
  321. parameters.push_back(new_param);
  322. (void)new_params.insert(new_param);
  323. }
  324. }
  325. for (auto &param : func_graph->parameters()) {
  326. if (new_params.find(param) == new_params.end()) {
  327. parameters.push_back(param);
  328. }
  329. }
  330. func_graph->set_parameters(parameters);
  331. }
  332. void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
  333. MS_EXCEPTION_IF_NULL(func_graph);
  334. for (auto &node : func_graph->nodes()) {
  335. if (node == nullptr) {
  336. continue;
  337. }
  338. // Only cnode needed to be handled
  339. if (!node->isa<CNode>()) {
  340. continue;
  341. }
  342. auto cnode = node->cast<CNodePtr>();
  343. auto &inputs = cnode->inputs();
  344. for (size_t i = 0; i < inputs.size(); i++) {
  345. auto &input = inputs[i];
  346. if (IsValueNode<FuncGraph>(input)) {
  347. auto graph = GetValueNode<FuncGraphPtr>(input);
  348. auto &repl_func_graph = repl_map_func_graph_[func_graph];
  349. if (repl_func_graph.find(graph) != repl_func_graph.end()) {
  350. transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
  351. }
  352. } else {
  353. auto &repl_node = repl_map_node_[func_graph];
  354. if (repl_node.find(input) != repl_node.end()) {
  355. transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
  356. }
  357. }
  358. }
  359. }
  360. }
  361. void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  362. const AnfNodePtrList &params) {
  363. AnfNodePtrList lift_params;
  364. AnfNodePtrList input_params;
  365. AddParameters(func_graph_user, params, &lift_params, &input_params);
  366. AddInputs(func_graph_user, func_graph, input_params);
  367. if (lift_params.empty()) {
  368. return;
  369. }
  370. for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
  371. LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
  372. }
  373. }
  374. void Cloner::Lift() {
  375. for (auto &func_graph_params : repl_func_graph_params_) {
  376. auto &func_graph = func_graph_params.first;
  377. auto &params = func_graph_params.second;
  378. for (auto &cnode : func_graph->func_graph_cnodes_index()) {
  379. LiftParameters(cnode.first->first->func_graph(), func_graph, params);
  380. }
  381. }
  382. }
  383. void Cloner::LiftParameters() {
  384. MS_EXCEPTION_IF_NULL(manager_);
  385. transaction_ = manager_->Transact();
  386. const FuncGraphSet &func_graphs = manager_->func_graphs();
  387. for (auto &func_graph : func_graphs) {
  388. GenParameters(func_graph);
  389. }
  390. Lift();
  391. for (auto &func_graph : func_graphs) {
  392. SetEdges(func_graph);
  393. }
  394. transaction_.Commit();
  395. }
  396. bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
  397. MS_EXCEPTION_IF_NULL(func_graph);
  398. // Make sure only inline once
  399. if (status_.count(func_graph) != 0) {
  400. if (is_inline == status_[func_graph]) {
  401. return false;
  402. }
  403. if (clone_all_used_graphs_) {
  404. MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
  405. return false;
  406. }
  407. }
  408. return true;
  409. }
  410. void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  411. MS_EXCEPTION_IF_NULL(func_graph);
  412. MS_EXCEPTION_IF_NULL(target_func_graph);
  413. MS_EXCEPTION_IF_NULL(manager_);
  414. const AnfNodeSet &nodes = func_graph->nodes();
  415. for (auto &node : nodes) {
  416. CloneNode(node, target_func_graph);
  417. }
  418. }
  419. void Cloner::Run() {
  420. if (todo_.empty()) {
  421. return;
  422. }
  423. if (type_ < kLifting) {
  424. // Basic and Inline Clone
  425. FuncGraphPtrList func_graphs;
  426. (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
  427. [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
  428. manager_ = Manage(func_graphs, false);
  429. CloneNodes();
  430. LinkEdges();
  431. SetDefaults();
  432. } else {
  433. // Lifting Clone
  434. CloneInfo item = todo_.back();
  435. manager_ = Manage(item.origin);
  436. LiftParameters();
  437. }
  438. }
  439. void Cloner::CloneNodes() {
  440. while (!todo_.empty()) {
  441. CloneInfo item = todo_.back();
  442. todo_.pop_back();
  443. bool is_inline = (item.target != nullptr);
  444. FuncGraphPtr func_graph = item.origin;
  445. FuncGraphPtr target_func_graph = item.target;
  446. (void)graph_set_.insert(func_graph);
  447. if (!CheckStatus(func_graph, is_inline)) {
  448. continue;
  449. }
  450. if (is_inline) {
  451. InlineCloneParameters(func_graph, item.params);
  452. CloneAllNodes(func_graph, target_func_graph);
  453. } else {
  454. SetFuncGraphInfo(func_graph, &target_func_graph);
  455. CloneParameters(func_graph, target_func_graph);
  456. CloneAllNodes(func_graph, target_func_graph);
  457. CloneFuncGraphValueNodes(func_graph, target_func_graph);
  458. CloneFuncGraphDefaultValues(func_graph, target_func_graph);
  459. }
  460. CloneValueNodes(func_graph);
  461. AddChildGraphs(func_graph);
  462. AddTotalGraphs(func_graph);
  463. status_[func_graph] = is_inline;
  464. }
  465. }
  466. void Cloner::LinkEdges() {
  467. for (auto &node_pair : nodes_) {
  468. CNodePtr old_node = node_pair.first;
  469. CNodePtr new_node = node_pair.second;
  470. MS_EXCEPTION_IF_NULL(old_node);
  471. MS_EXCEPTION_IF_NULL(new_node);
  472. for (auto &input : old_node->inputs()) {
  473. auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
  474. new_node->add_input(new_input);
  475. }
  476. }
  477. }
  478. // For the graphs cloned, update its default value map to the cloned nodes
  479. void Cloner::SetDefaults() {
  480. for (auto &item : graph_set_) {
  481. MS_EXCEPTION_IF_NULL(item);
  482. if (repl_func_graph_.count(item) != 0) {
  483. for (auto &param_def : item->parameter_default_value()) {
  484. MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
  485. if (repl_node_.count(param_def.second) != 0) {
  486. repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
  487. } else {
  488. repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second);
  489. }
  490. }
  491. }
  492. }
  493. }
  494. AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  495. MS_EXCEPTION_IF_NULL(root);
  496. if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
  497. MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  498. }
  499. CloneNode(root, repl_func_graph_[root->func_graph()]);
  500. auto iter = repl_node_.find(root);
  501. if (iter != repl_node_.end()) {
  502. return iter->second;
  503. }
  504. MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  505. }
  506. AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
  507. #ifdef ENABLE_PROFILE
  508. double time = GetTime();
  509. #endif
  510. Run();
  511. #ifdef ENABLE_PROFILE
  512. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time);
  513. #endif
  514. return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
  515. }
  516. FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
  517. #ifdef ENABLE_PROFILE
  518. double time = GetTime();
  519. #endif
  520. Run();
  521. #ifdef ENABLE_PROFILE
  522. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time);
  523. #endif
  524. return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
  525. }
  526. FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
  527. MS_EXCEPTION_IF_NULL(func_graph);
  528. Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
  529. return cloner[func_graph];
  530. }
  531. AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  532. const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
  533. MS_EXCEPTION_IF_NULL(func_graph);
  534. MS_EXCEPTION_IF_NULL(target_func_graph);
  535. Cloner cloner({}, false);
  536. if (scope != nullptr) {
  537. cloner.set_scope(scope);
  538. }
  539. cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
  540. return cloner[func_graph->output()];
  541. }
  542. FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
  543. MS_EXCEPTION_IF_NULL(func_graph);
  544. Cloner cloner({}, false);
  545. cloner.AddClone(func_graph, nullptr, {}, kLifting);
  546. return cloner[func_graph];
  547. }
  548. ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  549. MS_EXCEPTION_IF_NULL(func_graph);
  550. FuncGraphPtrList func_graphs = {func_graph};
  551. ClonerPtr cloner =
  552. std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
  553. #ifdef ENABLE_PROFILE
  554. double time = GetTime();
  555. #endif
  556. cloner->Run();
  557. #ifdef ENABLE_PROFILE
  558. MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time);
  559. #endif
  560. return cloner;
  561. }
  562. FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  563. MS_EXCEPTION_IF_NULL(func_graph);
  564. TraceManager::DebugTrace(func_graph->debug_info(), relation);
  565. auto new_func_graph = std::make_shared<FuncGraph>();
  566. TraceManager::EndTrace();
  567. auto &parameters = func_graph->parameters();
  568. (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
  569. MS_EXCEPTION_IF_NULL(param);
  570. TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
  571. (void)new_func_graph->add_parameter();
  572. TraceManager::EndTrace();
  573. });
  574. Cloner cloner = Cloner();
  575. cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  576. AnfNodePtr output = cloner[func_graph->output()];
  577. new_func_graph->set_output(output);
  578. new_func_graph->set_has_vararg(func_graph->has_vararg());
  579. new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  580. new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  581. new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  582. new_func_graph->set_is_generate(func_graph->is_generated());
  583. for (auto &item : func_graph->parameter_default_value()) {
  584. new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  585. }
  586. if (MsContext::GetInstance()->is_multi_graph_sink()) {
  587. if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
  588. new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  589. }
  590. }
  591. if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  592. new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  593. }
  594. return new_func_graph;
  595. }
  596. } // namespace mindspore