You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

func_graph_cloner.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/func_graph_cloner.h"
  17. #include <algorithm>
  18. #include "ir/manager.h"
  19. #include "ir/param_value.h"
  20. #include "operator/ops.h"
  21. #include "utils/convert_utils_base.h"
  22. #include "utils/log_adapter.h"
  23. #include "utils/profile.h"
  24. #include "utils/context/ms_context.h"
  25. // namespace to support intermediate representation definition
  26. namespace mindspore {
  27. Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
  28. bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
  29. : clone_all_valuenodes_(clone_all_valuenodes),
  30. clone_all_child_graphs_(clone_all_child_graphs),
  31. clone_all_used_graphs_(clone_all_used_graphs),
  32. relation_(relation),
  33. target_relation_(target_relation == nullptr ? relation : target_relation) {
  34. for (auto &func_graph : func_graphs) {
  35. AddClone(func_graph);
  36. }
  37. scope_ = kDefaultScope;
  38. type_ = kBasic;
  39. }
  40. void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  41. const AnfNodePtrList &params, CloneType type) {
  42. if (func_graph != nullptr) {
  43. todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
  44. type_ = type;
  45. }
  46. }
  47. void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  48. MS_EXCEPTION_IF_NULL(node);
  49. if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
  50. return;
  51. }
  52. if (node->isa<Parameter>()) {
  53. CloneParameter(node, target);
  54. } else if (node->isa<CNode>()) {
  55. CloneCNode(node, target);
  56. }
  57. }
  58. void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
  59. MS_EXCEPTION_IF_NULL(node);
  60. MS_EXCEPTION_IF_NULL(target);
  61. TraceManager::DebugTrace(node->debug_info(), relation_);
  62. auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
  63. auto old_param = node->cast<ParameterPtr>();
  64. new_param->set_abstract(old_param->abstract());
  65. new_param->set_name(old_param->name());
  66. if (old_param->has_default()) {
  67. // Default parameter can be shared since it is readonly.
  68. new_param->set_default_param(old_param->default_param());
  69. }
  70. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  71. new_param->set_scope(scope);
  72. repl_node_[node] = new_param;
  73. TraceManager::EndTrace();
  74. }
  75. void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  76. MS_EXCEPTION_IF_NULL(node);
  77. MS_EXCEPTION_IF_NULL(target);
  78. TraceManager::DebugTrace(node->debug_info(), relation_);
  79. CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
  80. auto old_node = node->cast<CNodePtr>();
  81. new_node->set_abstract(old_node->abstract());
  82. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  83. new_node->set_scope(scope);
  84. new_node->set_kernel_info(old_node->kernel_info_ptr());
  85. repl_node_[old_node] = new_node;
  86. nodes_.emplace_back(old_node, new_node);
  87. TraceManager::EndTrace();
  88. }
  89. void Cloner::CloneValueNode(const AnfNodePtr &node) {
  90. MS_EXCEPTION_IF_NULL(node);
  91. TraceManager::DebugTrace(node->debug_info(), relation_);
  92. ValueNodePtr new_const = NewValueNode(GetValueNode(node));
  93. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  94. new_const->set_scope(scope);
  95. new_const->set_abstract(node->abstract());
  96. repl_node_[node] = new_const;
  97. TraceManager::EndTrace();
  98. }
  99. void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  100. MS_EXCEPTION_IF_NULL(node);
  101. MS_EXCEPTION_IF_NULL(target);
  102. TraceManager::DebugTrace(node->debug_info(), relation_);
  103. ValueNodePtr new_const = NewValueNode(target);
  104. ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  105. new_const->set_scope(scope);
  106. new_const->set_abstract(node->abstract());
  107. repl_node_[node] = new_const;
  108. TraceManager::EndTrace();
  109. }
  110. void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
  111. MS_EXCEPTION_IF_NULL(func_graph);
  112. MS_EXCEPTION_IF_NULL(manager_);
  113. if (!clone_all_valuenodes_) {
  114. return;
  115. }
  116. auto &value_nodes = func_graph->value_nodes();
  117. for (auto &value_node : value_nodes) {
  118. auto old_node = value_node.first;
  119. MS_EXCEPTION_IF_NULL(old_node);
  120. if (repl_node_.count(old_node) == 0) {
  121. CloneValueNode(old_node);
  122. }
  123. }
  124. }
  125. void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
  126. MS_EXCEPTION_IF_NULL(func_graph);
  127. MS_EXCEPTION_IF_NULL(manager_);
  128. if (!clone_all_child_graphs_) {
  129. return;
  130. }
  131. auto &scopes = manager_->scopes(func_graph);
  132. for (auto &graph : scopes) {
  133. if (graph != func_graph) {
  134. todo_.push_back({graph, nullptr, {}});
  135. }
  136. }
  137. }
  138. void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
  139. MS_EXCEPTION_IF_NULL(func_graph);
  140. MS_EXCEPTION_IF_NULL(manager_);
  141. if (!clone_all_used_graphs_) {
  142. return;
  143. }
  144. auto &used = func_graph->func_graphs_used();
  145. for (auto &fg : used) {
  146. todo_.push_back({fg.first, nullptr, {}});
  147. }
  148. }
  149. void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  150. MS_EXCEPTION_IF_NULL(func_graph);
  151. MS_EXCEPTION_IF_NULL(target_func_graph);
  152. for (auto &item : func_graph->parameter_default_value()) {
  153. auto nodes = DeepLinkedGraphSearch(item.second);
  154. for (auto &node : nodes) {
  155. MS_EXCEPTION_IF_NULL(node);
  156. if (node->isa<CNode>()) {
  157. CloneNode(node, target_func_graph);
  158. } else if (node->isa<ValueNode>()) {
  159. CloneValueNode(node);
  160. }
  161. }
  162. }
  163. }
  164. void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  165. MS_EXCEPTION_IF_NULL(func_graph);
  166. MS_EXCEPTION_IF_NULL(target_func_graph);
  167. MS_EXCEPTION_IF_NULL(manager_);
  168. auto return_node = repl_node_[func_graph->get_return()]->cast<CNodePtr>();
  169. if (return_node == nullptr) {
  170. MS_LOG(EXCEPTION) << "Can't find replicate node for return.";
  171. }
  172. target_func_graph->set_return(return_node);
  173. auto &cnodes = func_graph->func_graph_cnodes_index();
  174. for (auto &cnode : cnodes) {
  175. auto parent = cnode.first->first->cast<CNodePtr>();
  176. auto valuenode = parent->input(cnode.first->second);
  177. CloneValueNode(valuenode, target_func_graph);
  178. }
  179. }
  180. void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
  181. MS_EXCEPTION_IF_NULL(func_graph);
  182. auto &old_params = func_graph->parameters();
  183. if (old_params.size() != params.size()) {
  184. MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
  185. return;
  186. }
  187. for (size_t i = 0; i < old_params.size(); ++i) {
  188. repl_node_[old_params[i]] = params[i];
  189. }
  190. }
  191. void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
  192. MS_EXCEPTION_IF_NULL(func_graph);
  193. MS_EXCEPTION_IF_NULL(target_func_graph);
  194. TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
  195. *target_func_graph = std::make_shared<FuncGraph>();
  196. (*target_func_graph)->set_attrs(func_graph->attrs());
  197. (*target_func_graph)->set_transforms(func_graph->transforms());
  198. (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
  199. (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
  200. (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  201. (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
  202. (*target_func_graph)->set_is_generate(func_graph->is_generated());
  203. (*target_func_graph)->set_stub(func_graph->stub());
  204. TraceManager::EndTrace();
  205. }
  206. void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  207. MS_EXCEPTION_IF_NULL(func_graph);
  208. MS_EXCEPTION_IF_NULL(target_func_graph);
  209. auto &params = func_graph->parameters();
  210. for (auto &param : params) {
  211. CloneParameter(param, target_func_graph, true);
  212. }
  213. repl_func_graph_[func_graph] = target_func_graph;
  214. }
  215. void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
  216. MS_EXCEPTION_IF_NULL(func_graph);
  217. auto &free_vars = manager_->free_variables_total();
  218. auto iter = free_vars.find(func_graph);
  219. if (iter == free_vars.end()) {
  220. return;
  221. }
  222. for (auto &fv_map : iter->second) {
  223. auto &free_var = fv_map.first;
  224. if (utils::isa<AnfNodePtr>(free_var)) {
  225. repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
  226. }
  227. }
  228. }
  229. void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
  230. param->set_abstract(node->abstract());
  231. if (node->isa<Parameter>()) {
  232. ParameterPtr old_param = dyn_cast<Parameter>(node);
  233. if (old_param->has_default()) {
  234. // Default parameter can be shared since it is readonly.
  235. param->set_default_param(old_param->default_param());
  236. }
  237. param->set_name(old_param->name());
  238. }
  239. }
  240. ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
  241. TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
  242. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  243. TraceManager::EndTrace();
  244. CloneParameter(param, node);
  245. if (is_add) {
  246. func_graph->add_parameter(param);
  247. }
  248. repl_node_[param] = node;
  249. repl_map_node_[func_graph][node] = param;
  250. return param;
  251. }
  252. void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
  253. AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
  254. AnfNodePtrList parameters;
  255. std::unordered_set<AnfNodePtr> old_params;
  256. for (auto &param : func_graph->parameters()) {
  257. auto iter = repl_node_.find(param);
  258. if (iter != repl_node_.end()) {
  259. (void)old_params.insert(iter->second);
  260. parameters.push_back(param);
  261. } else {
  262. parameters.push_back(AddParameter(func_graph, param, false));
  263. (void)old_params.insert(param);
  264. }
  265. }
  266. AnfNodePtr new_param = nullptr;
  267. for (auto &param : params) {
  268. auto old_param = repl_node_[param];
  269. if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
  270. repl_node_[old_param] = old_param;
  271. repl_map_node_[func_graph][old_param] = old_param;
  272. input_params->push_back(old_param);
  273. continue;
  274. }
  275. if (old_params.find(old_param) != old_params.end()) {
  276. new_param = repl_map_node_[func_graph][old_param];
  277. input_params->push_back(new_param);
  278. continue;
  279. }
  280. new_param = AddParameter(func_graph, old_param, false);
  281. parameters.push_back(new_param);
  282. lift_params->push_back(new_param);
  283. input_params->push_back(new_param);
  284. }
  285. func_graph->set_parameters(parameters);
  286. }
  287. void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  288. const AnfNodePtrList &params) {
  289. AnfNodePtr node = nullptr;
  290. auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
  291. auto iter = repl_func_graph.find(func_graph);
  292. if (iter == repl_func_graph.end()) {
  293. node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
  294. repl_func_graph[func_graph] = node;
  295. } else {
  296. node = iter->second;
  297. }
  298. if (node == nullptr || !node->isa<CNode>()) {
  299. return;
  300. }
  301. auto cnode = node->cast<CNodePtr>();
  302. auto inputs = cnode->inputs();
  303. (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
  304. cnode->set_inputs(inputs);
  305. OrderParameters(func_graph, inputs);
  306. }
  307. void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
  308. std::unordered_set<AnfNodePtr> old_params;
  309. for (auto &param : func_graph->parameters()) {
  310. (void)old_params.insert(repl_node_[param]);
  311. }
  312. std::unordered_set<AnfNodePtr> new_params;
  313. AnfNodePtrList parameters;
  314. // Ignore the 1st and 2nd param of inputs(such as. partial graph)
  315. for (size_t i = 2; i < inputs.size(); ++i) {
  316. auto input = inputs[i];
  317. auto param = repl_node_[input];
  318. if (old_params.find(param) != old_params.end()) {
  319. auto new_param = repl_map_node_[func_graph][param];
  320. parameters.push_back(new_param);
  321. (void)new_params.insert(new_param);
  322. }
  323. }
  324. for (auto &param : func_graph->parameters()) {
  325. if (new_params.find(param) == new_params.end()) {
  326. parameters.push_back(param);
  327. }
  328. }
  329. func_graph->set_parameters(parameters);
  330. }
  331. void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
  332. MS_EXCEPTION_IF_NULL(func_graph);
  333. for (auto &node : func_graph->nodes()) {
  334. if (node == nullptr) {
  335. continue;
  336. }
  337. // Only cnode needed to be handled
  338. if (!node->isa<CNode>()) {
  339. continue;
  340. }
  341. auto cnode = node->cast<CNodePtr>();
  342. auto &inputs = cnode->inputs();
  343. for (size_t i = 0; i < inputs.size(); i++) {
  344. auto &input = inputs[i];
  345. if (IsValueNode<FuncGraph>(input)) {
  346. auto graph = GetValueNode<FuncGraphPtr>(input);
  347. auto &repl_func_graph = repl_map_func_graph_[func_graph];
  348. if (repl_func_graph.find(graph) != repl_func_graph.end()) {
  349. transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
  350. }
  351. } else {
  352. auto &repl_node = repl_map_node_[func_graph];
  353. if (repl_node.find(input) != repl_node.end()) {
  354. transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
  355. }
  356. }
  357. }
  358. }
  359. }
  360. void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
  361. const AnfNodePtrList &params) {
  362. AnfNodePtrList lift_params;
  363. AnfNodePtrList input_params;
  364. AddParameters(func_graph_user, params, &lift_params, &input_params);
  365. AddInputs(func_graph_user, func_graph, input_params);
  366. if (lift_params.empty()) {
  367. return;
  368. }
  369. for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
  370. LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
  371. }
  372. }
  373. void Cloner::Lift() {
  374. for (auto &func_graph_params : repl_func_graph_params_) {
  375. auto &func_graph = func_graph_params.first;
  376. auto &params = func_graph_params.second;
  377. for (auto &cnode : func_graph->func_graph_cnodes_index()) {
  378. LiftParameters(cnode.first->first->func_graph(), func_graph, params);
  379. }
  380. }
  381. }
  382. void Cloner::LiftParameters() {
  383. MS_EXCEPTION_IF_NULL(manager_);
  384. transaction_ = manager_->Transact();
  385. const FuncGraphSet &func_graphs = manager_->func_graphs();
  386. for (auto &func_graph : func_graphs) {
  387. GenParameters(func_graph);
  388. }
  389. Lift();
  390. for (auto &func_graph : func_graphs) {
  391. SetEdges(func_graph);
  392. }
  393. transaction_.Commit();
  394. }
  395. bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
  396. MS_EXCEPTION_IF_NULL(func_graph);
  397. // Make sure only inline once
  398. if (status_.count(func_graph) != 0) {
  399. if (is_inline == status_[func_graph]) {
  400. return false;
  401. }
  402. if (clone_all_used_graphs_) {
  403. MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
  404. return false;
  405. }
  406. }
  407. return true;
  408. }
  409. void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
  410. MS_EXCEPTION_IF_NULL(func_graph);
  411. MS_EXCEPTION_IF_NULL(target_func_graph);
  412. MS_EXCEPTION_IF_NULL(manager_);
  413. const AnfNodeSet &nodes = func_graph->nodes();
  414. for (auto &node : nodes) {
  415. CloneNode(node, target_func_graph);
  416. }
  417. }
  418. void Cloner::Run() {
  419. if (todo_.empty()) {
  420. return;
  421. }
  422. if (type_ < kLifting) {
  423. // Basic and Inline Clone
  424. FuncGraphPtrList func_graphs;
  425. (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
  426. [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
  427. manager_ = Manage(func_graphs, false);
  428. CloneNodes();
  429. LinkEdges();
  430. SetDefaults();
  431. } else {
  432. // Lifting Clone
  433. CloneInfo item = todo_.back();
  434. manager_ = Manage(item.origin);
  435. LiftParameters();
  436. }
  437. }
  438. void Cloner::CloneNodes() {
  439. while (!todo_.empty()) {
  440. CloneInfo item = todo_.back();
  441. todo_.pop_back();
  442. bool is_inline = (item.target != nullptr);
  443. FuncGraphPtr func_graph = item.origin;
  444. FuncGraphPtr target_func_graph = item.target;
  445. (void)graph_set_.insert(func_graph);
  446. if (!CheckStatus(func_graph, is_inline)) {
  447. continue;
  448. }
  449. if (is_inline) {
  450. InlineCloneParameters(func_graph, item.params);
  451. CloneAllNodes(func_graph, target_func_graph);
  452. } else {
  453. SetFuncGraphInfo(func_graph, &target_func_graph);
  454. CloneParameters(func_graph, target_func_graph);
  455. CloneAllNodes(func_graph, target_func_graph);
  456. CloneFuncGraphValueNodes(func_graph, target_func_graph);
  457. CloneFuncGraphDefaultValues(func_graph, target_func_graph);
  458. }
  459. CloneValueNodes(func_graph);
  460. AddChildGraphs(func_graph);
  461. AddTotalGraphs(func_graph);
  462. status_[func_graph] = is_inline;
  463. }
  464. }
  465. void Cloner::LinkEdges() {
  466. for (auto &node_pair : nodes_) {
  467. CNodePtr old_node = node_pair.first;
  468. CNodePtr new_node = node_pair.second;
  469. MS_EXCEPTION_IF_NULL(old_node);
  470. MS_EXCEPTION_IF_NULL(new_node);
  471. for (auto &input : old_node->inputs()) {
  472. auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
  473. new_node->add_input(new_input);
  474. }
  475. }
  476. }
  477. // For the graphs cloned, update its default value map to the cloned nodes
  478. void Cloner::SetDefaults() {
  479. for (auto &item : graph_set_) {
  480. MS_EXCEPTION_IF_NULL(item);
  481. if (repl_func_graph_.count(item) != 0) {
  482. for (auto &param_def : item->parameter_default_value()) {
  483. MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
  484. if (repl_node_.count(param_def.second) != 0) {
  485. repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
  486. } else {
  487. repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second);
  488. }
  489. }
  490. }
  491. }
  492. }
  493. AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  494. MS_EXCEPTION_IF_NULL(root);
  495. if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
  496. MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  497. }
  498. CloneNode(root, repl_func_graph_[root->func_graph()]);
  499. auto iter = repl_node_.find(root);
  500. if (iter != repl_node_.end()) {
  501. return iter->second;
  502. }
  503. MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  504. }
  505. AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
  506. #ifdef ENABLE_PROFILE
  507. double time = GetTime();
  508. #endif
  509. Run();
  510. #ifdef ENABLE_PROFILE
  511. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time);
  512. #endif
  513. return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
  514. }
  515. FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
  516. #ifdef ENABLE_PROFILE
  517. double time = GetTime();
  518. #endif
  519. Run();
  520. #ifdef ENABLE_PROFILE
  521. MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time);
  522. #endif
  523. return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
  524. }
  525. FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
  526. MS_EXCEPTION_IF_NULL(func_graph);
  527. Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
  528. return cloner[func_graph];
  529. }
  530. AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
  531. const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
  532. MS_EXCEPTION_IF_NULL(func_graph);
  533. MS_EXCEPTION_IF_NULL(target_func_graph);
  534. Cloner cloner({}, false);
  535. if (scope != nullptr) {
  536. cloner.set_scope(scope);
  537. }
  538. cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
  539. return cloner[func_graph->output()];
  540. }
  541. FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
  542. MS_EXCEPTION_IF_NULL(func_graph);
  543. Cloner cloner({}, false);
  544. cloner.AddClone(func_graph, nullptr, {}, kLifting);
  545. return cloner[func_graph];
  546. }
  547. ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  548. MS_EXCEPTION_IF_NULL(func_graph);
  549. FuncGraphPtrList func_graphs = {func_graph};
  550. ClonerPtr cloner =
  551. std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
  552. #ifdef ENABLE_PROFILE
  553. double time = GetTime();
  554. #endif
  555. cloner->Run();
  556. #ifdef ENABLE_PROFILE
  557. MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time);
  558. #endif
  559. return cloner;
  560. }
  561. FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  562. MS_EXCEPTION_IF_NULL(func_graph);
  563. TraceManager::DebugTrace(func_graph->debug_info(), relation);
  564. auto new_func_graph = std::make_shared<FuncGraph>();
  565. TraceManager::EndTrace();
  566. auto &parameters = func_graph->parameters();
  567. (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
  568. MS_EXCEPTION_IF_NULL(param);
  569. TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
  570. (void)new_func_graph->add_parameter();
  571. TraceManager::EndTrace();
  572. });
  573. Cloner cloner = Cloner();
  574. cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  575. AnfNodePtr output = cloner[func_graph->output()];
  576. new_func_graph->set_output(output);
  577. new_func_graph->set_has_vararg(func_graph->has_vararg());
  578. new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  579. new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  580. new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  581. new_func_graph->set_is_generate(func_graph->is_generated());
  582. new_func_graph->set_stub(func_graph->stub());
  583. for (auto &item : func_graph->parameter_default_value()) {
  584. new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  585. }
  586. if (MsContext::GetInstance()->is_multi_graph_sink()) {
  587. if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
  588. new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  589. }
  590. }
  591. if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
  592. new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  593. }
  594. return new_func_graph;
  595. }
  596. } // namespace mindspore