You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cc 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/common/helper.h"
  17. #include <string>
  18. #include <utility>
  19. #include <unordered_set>
  20. #include <algorithm>
  21. #include <map>
  22. #include <set>
  23. #include <deque>
  24. #include "utils/utils.h"
  25. #include "base/base_ref.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "base/core_ops.h"
  28. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  29. #include "frontend/operator/ops.h"
  30. #include "utils/ms_utils.h"
  31. #include "runtime/device/kernel_info.h"
  32. #include "utils/ms_context.h"
  33. namespace mindspore {
  34. namespace opt {
  // Byte length associated with a 32-bit type (presumably 32 bits / 8 = 4 bytes).
  // NOTE(review): not referenced in the visible part of this file — confirm it is still used.
  constexpr size_t kType32Len{4};
  36. std::vector<int> Convert2Int(const std::vector<size_t> &v) {
  37. std::vector<int> result;
  38. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  39. return result;
  40. }
  41. bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
  42. MS_EXCEPTION_IF_NULL(node);
  43. std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
  44. std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
  45. for (auto &nd : node_list) {
  46. MS_EXCEPTION_IF_NULL(nd);
  47. if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
  48. auto control_depend = nd->cast<CNodePtr>();
  49. auto prior_node = control_depend->input(kControlDependPriorIndex);
  50. auto behind_node = control_depend->input(kControlDependBehindIndex);
  51. auto it = control_depend_map.find(behind_node);
  52. if (it == control_depend_map.end()) {
  53. control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
  54. } else {
  55. it->second.insert(prior_node);
  56. }
  57. }
  58. }
  59. FuncGraphManagerPtr manager = graph.manager();
  60. MS_EXCEPTION_IF_NULL(manager);
  61. std::unordered_set<AnfNodePtr> seen_node;
  62. std::deque<AnfNodePtr> todo{node};
  63. while (!todo.empty()) {
  64. AnfNodePtr nd = todo.front();
  65. todo.pop_front();
  66. if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
  67. continue;
  68. }
  69. (void)seen_node.insert(nd);
  70. if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
  71. return true;
  72. }
  73. if (nd->isa<CNode>()) {
  74. auto cnode = nd->cast<CNodePtr>();
  75. MS_EXCEPTION_IF_NULL(cnode);
  76. auto inputs = cnode->inputs();
  77. (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
  78. }
  79. auto it = control_depend_map.find(nd);
  80. if (it != control_depend_map.end()) {
  81. (void)todo.insert(todo.end(), it->second.begin(), it->second.end());
  82. }
  83. }
  84. return false;
  85. }
  86. bool UnVisited(const BaseRef &n) {
  87. if (utils::isa<AnfNodePtr>(n)) {
  88. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  89. MS_EXCEPTION_IF_NULL(in);
  90. if (IsValueNode<Primitive>(in)) {
  91. auto value_node = in->cast<ValueNodePtr>();
  92. MS_EXCEPTION_IF_NULL(value_node);
  93. auto value = value_node->value();
  94. MS_EXCEPTION_IF_NULL(value);
  95. auto prim_py = value->cast<PrimitivePtr>();
  96. MS_EXCEPTION_IF_NULL(prim_py);
  97. return !prim_py->HasAttr(kAttrVisited);
  98. } else if (IsValueNode<FuncGraph>(in)) {
  99. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  100. MS_EXCEPTION_IF_NULL(func_graph);
  101. return !func_graph->has_flag(kAttrVisited);
  102. }
  103. return false;
  104. }
  105. return false;
  106. }
  107. bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) {
  108. MS_EXCEPTION_IF_NULL(node);
  109. if (!node->isa<CNode>()) {
  110. MS_LOG(ERROR) << "The node is expected to be a cnode";
  111. return false;
  112. }
  113. *cnode = node->cast<CNodePtr>();
  114. if (*cnode == nullptr) {
  115. return false;
  116. }
  117. if ((*cnode)->inputs().size() < IntToSize(input_size)) {
  118. auto op_name = AnfAlgo::GetCNodeName(*cnode);
  119. MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  120. return false;
  121. }
  122. return true;
  123. }
  124. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) {
  125. MS_EXCEPTION_IF_NULL(node);
  126. if (!node->isa<CNode>()) {
  127. MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  128. }
  129. auto cnode = node->cast<CNodePtr>();
  130. MS_EXCEPTION_IF_NULL(cnode);
  131. if (cnode->inputs().size() != IntToSize(input_size)) {
  132. auto op_name = AnfAlgo::GetCNodeName(cnode);
  133. MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  134. }
  135. return cnode;
  136. }
  137. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) {
  138. MS_EXCEPTION_IF_NULL(cnode);
  139. if (cnode->inputs().size() != input_size) {
  140. MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size;
  141. }
  142. }
  143. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  144. MS_EXCEPTION_IF_NULL(node_x);
  145. MS_EXCEPTION_IF_NULL(node_y);
  146. return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
  147. AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
  148. }
  149. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  150. MS_EXCEPTION_IF_NULL(func_graph);
  151. auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
  152. MS_EXCEPTION_IF_NULL(transop_cnode);
  153. auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
  154. auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
  155. MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
  156. MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1));
  157. auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1);
  158. MS_EXCEPTION_IF_NULL(transed_node);
  159. std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
  160. depend_cnode->input(kDependInputNum - 1)};
  161. AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  162. MS_EXCEPTION_IF_NULL(replace_depend);
  163. auto transed_abstract = transed_node->abstract();
  164. replace_depend->set_abstract(transed_abstract);
  165. return replace_depend;
  166. }
  167. bool Visited(const BaseRef &n) {
  168. if (utils::isa<AnfNodePtr>(n)) {
  169. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  170. MS_EXCEPTION_IF_NULL(in);
  171. if (IsValueNode<Primitive>(in)) {
  172. auto value_node = in->cast<ValueNodePtr>();
  173. MS_EXCEPTION_IF_NULL(value_node);
  174. auto value = value_node->value();
  175. MS_EXCEPTION_IF_NULL(value);
  176. auto prim_py = value->cast<PrimitivePtr>();
  177. MS_EXCEPTION_IF_NULL(prim_py);
  178. return prim_py->HasAttr(kAttrVisited);
  179. } else if (IsValueNode<FuncGraph>(in)) {
  180. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  181. MS_EXCEPTION_IF_NULL(func_graph);
  182. return func_graph->has_flag(kAttrVisited);
  183. }
  184. return false;
  185. }
  186. return false;
  187. }
  188. void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode,
  189. std::vector<AnfNodePtr> *conv_bn1_outputs) {
  190. auto prim = std::make_shared<Primitive>(kConvBN1OpName);
  191. std::vector<AnfNodePtr> conv_bn1_inputs = {NewValueNode(prim)};
  192. MS_EXCEPTION_IF_NULL(conv_cnode);
  193. // All the inputs of conv_bn1 are from the inputs of conv
  194. for (size_t i = 1; i < conv_cnode->inputs().size(); i++) {
  195. conv_bn1_inputs.push_back(conv_cnode->input(i));
  196. }
  197. MS_EXCEPTION_IF_NULL(func_graph);
  198. CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs);
  199. MS_EXCEPTION_IF_NULL(conv_bn1_cnode);
  200. auto kernel_info = std::make_shared<device::KernelInfo>();
  201. conv_bn1_cnode->set_kernel_info(kernel_info);
  202. // Set attr for conv_bn1
  203. AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode);
  204. // Set abstract of conv_bn1
  205. MS_EXCEPTION_IF_NULL(bn_cnode);
  206. auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_cnode->abstract());
  207. MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  208. AbstractBasePtrList conv_bn1_abstract_list;
  209. conv_bn1_abstract_list.push_back(conv_cnode->abstract());
  210. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(
  211. kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1)));
  212. conv_bn1_abstract_list.push_back(abstract_tensor);
  213. conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]);
  214. auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(conv_bn1_abstract_list);
  215. conv_bn1_cnode->set_abstract(abstract_tuple);
  216. CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs);
  217. }
  218. void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &fused_bn1_outputs,
  219. const CNodePtr &bn_node, std::vector<AnfNodePtr> *fused_bn2_outputs) {
  220. MS_EXCEPTION_IF_NULL(graph);
  221. MS_EXCEPTION_IF_NULL(bn_node);
  222. MS_EXCEPTION_IF_NULL(fused_bn2_outputs);
  223. if (bn_node->inputs().size() != kBnInputNum) {
  224. MS_LOG(EXCEPTION) << "BN node has wrong input size";
  225. }
  226. if (fused_bn1_outputs.size() != kBN1OutputNum) {
  227. MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  228. }
  229. // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn
  230. std::vector<AnfNodePtr> fused_bn2_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN2OpName))};
  231. fused_bn2_inputs.push_back(fused_bn1_outputs[0]);
  232. fused_bn2_inputs.push_back(fused_bn1_outputs[1]);
  233. fused_bn2_inputs.push_back(bn_node->input(4));
  234. fused_bn2_inputs.push_back(bn_node->input(5));
  235. auto fused_bn2 = graph->NewCNode(fused_bn2_inputs);
  236. MS_EXCEPTION_IF_NULL(fused_bn2);
  237. auto kernel_info = std::make_shared<device::KernelInfo>();
  238. fused_bn2->set_kernel_info(kernel_info);
  239. auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1),
  240. AnfAlgo::GetOutputInferDataType(bn_node, 2)};
  241. auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1),
  242. AnfAlgo::GetOutputInferShape(bn_node, 2)};
  243. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get());
  244. fused_bn2->set_scope(bn_node->scope());
  245. AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2);
  246. CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs);
  247. }
  248. void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
  249. const std::vector<AnfNodePtr> &fused_bn1_outputs,
  250. const std::vector<AnfNodePtr> &fused_bn2_outputs, const CNodePtr &bn_node,
  251. std::vector<AnfNodePtr> *fused_bn3_outputs) {
  252. MS_EXCEPTION_IF_NULL(graph);
  253. MS_EXCEPTION_IF_NULL(data_input);
  254. MS_EXCEPTION_IF_NULL(bn_node);
  255. MS_EXCEPTION_IF_NULL(fused_bn3_outputs);
  256. if (bn_node->inputs().size() != kBnInputNum) {
  257. MS_LOG(EXCEPTION) << "BN node has wrong input size";
  258. }
  259. if (fused_bn1_outputs.size() != kBN1OutputNum) {
  260. MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  261. }
  262. if (fused_bn2_outputs.size() != kBN2OutputNum) {
  263. MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size";
  264. }
  265. // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn
  266. std::vector<AnfNodePtr> fused_bn3_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN3OpName))};
  267. fused_bn3_inputs.push_back(data_input);
  268. fused_bn3_inputs.push_back(fused_bn1_outputs[0]);
  269. fused_bn3_inputs.push_back(fused_bn2_outputs[0]);
  270. fused_bn3_inputs.push_back(bn_node->input(2));
  271. fused_bn3_inputs.push_back(bn_node->input(3));
  272. auto fused_bn3 = graph->NewCNode(fused_bn3_inputs);
  273. MS_EXCEPTION_IF_NULL(fused_bn3);
  274. auto kernel_info = std::make_shared<device::KernelInfo>();
  275. fused_bn3->set_kernel_info(kernel_info);
  276. auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)};
  277. auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)};
  278. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get());
  279. fused_bn3->set_scope(bn_node->scope());
  280. AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3);
  281. (*fused_bn3_outputs).push_back(fused_bn3);
  282. }
  283. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  284. std::vector<AnfNodePtr> *outputs) {
  285. MS_EXCEPTION_IF_NULL(func_graph);
  286. MS_EXCEPTION_IF_NULL(node);
  287. MS_EXCEPTION_IF_NULL(outputs);
  288. for (size_t i = 0; i < output_num; i++) {
  289. int temp = SizeToInt(i);
  290. auto idx = NewValueNode(temp);
  291. MS_EXCEPTION_IF_NULL(idx);
  292. auto imm = std::make_shared<Int32Imm>(temp);
  293. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  294. idx->set_abstract(abstract_scalar);
  295. auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  296. MS_EXCEPTION_IF_NULL(tuple_getitem);
  297. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
  298. {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
  299. (*outputs).push_back(tuple_getitem);
  300. }
  301. }
  302. template <typename T>
  303. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  304. size_t data_length) {
  305. MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  306. MS_EXCEPTION_IF_NULL(type_ptr);
  307. std::vector<T> values;
  308. for (const auto &v : value_tuple_ptr->value()) {
  309. MS_EXCEPTION_IF_NULL(v);
  310. if (v->isa<Scalar>()) {
  311. ScalarPtr scalar = v->cast<ScalarPtr>();
  312. values.push_back(GetValue<T>(scalar));
  313. } else {
  314. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  315. return nullptr;
  316. }
  317. }
  318. std::vector<int> tensor_shape = {SizeToInt(values.size())};
  319. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  320. MS_EXCEPTION_IF_NULL(tensor);
  321. tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  322. tensor->set_device_info(device_info);
  323. auto data_ptr = tensor->data_c();
  324. MS_EXCEPTION_IF_NULL(data_ptr);
  325. auto elem_num = values.size() * data_length;
  326. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  327. if (ret_code != 0) {
  328. MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  329. }
  330. return tensor;
  331. }
  332. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  333. MS_EXCEPTION_IF_NULL(value_tuple);
  334. tensor::TensorPtr tensor = nullptr;
  335. if (value_tuple->value().empty()) {
  336. MS_LOG(WARNING) << "The value tuple is empty.";
  337. return nullptr;
  338. }
  339. ValuePtr v = *(value_tuple->value().begin());
  340. MS_EXCEPTION_IF_NULL(v);
  341. // Currently we only deal with the scalar tuple
  342. if (!v->isa<Scalar>()) {
  343. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  344. return nullptr;
  345. }
  346. ScalarPtr scalar = v->cast<ScalarPtr>();
  347. MS_EXCEPTION_IF_NULL(scalar);
  348. if (scalar->isa<Int32Imm>()) {
  349. tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
  350. } else if (scalar->isa<Int64Imm>()) {
  351. tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
  352. } else if (scalar->isa<FloatImm>()) {
  353. tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
  354. } else {
  355. auto type = scalar->type();
  356. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  357. MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
  358. return nullptr;
  359. }
  360. return tensor;
  361. }
  362. bool IsNopNode(const AnfNodePtr &node) {
  363. auto context_ptr = MsContext::GetInstance();
  364. MS_EXCEPTION_IF_NULL(context_ptr);
  365. if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
  366. context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
  367. return false;
  368. }
  369. static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
  370. prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
  371. kFlattenGradOpName};
  372. if (node == nullptr || !node->isa<CNode>()) {
  373. return false;
  374. }
  375. CNodePtr cnode = node->cast<CNodePtr>();
  376. MS_EXCEPTION_IF_NULL(cnode);
  377. bool is_nop_node = false;
  378. if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
  379. is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
  380. }
  381. if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
  382. return false;
  383. }
  384. return true;
  385. }
  386. bool IsAllNopNode(const session::KernelGraph *const graph) {
  387. MS_EXCEPTION_IF_NULL(graph);
  388. auto execution_order = graph->execution_order();
  389. for (auto &cnode : execution_order) {
  390. MS_EXCEPTION_IF_NULL(cnode);
  391. if (!IsNopNode(cnode)) {
  392. return false;
  393. }
  394. }
  395. return true;
  396. }
  397. void HideNopNode(session::KernelGraph *const graph) {
  398. MS_EXCEPTION_IF_NULL(graph);
  399. if (IsAllNopNode(graph) == true) {
  400. return;
  401. }
  402. auto execution_order = graph->execution_order();
  403. MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  404. std::vector<CNodePtr> new_nodes;
  405. for (auto &cnode : execution_order) {
  406. MS_EXCEPTION_IF_NULL(cnode);
  407. if (!IsNopNode(cnode)) {
  408. new_nodes.push_back(cnode);
  409. }
  410. }
  411. graph->set_execution_order(new_nodes);
  412. MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
  413. }
  414. void RemoveNopNode(session::KernelGraph *const graph) {
  415. MS_EXCEPTION_IF_NULL(graph);
  416. if (IsAllNopNode(graph) == true) {
  417. return;
  418. }
  419. bool changed = true;
  420. while (changed) {
  421. changed = false;
  422. std::vector<CNodePtr> new_nodes;
  423. for (auto &cnode : graph->execution_order()) {
  424. MS_EXCEPTION_IF_NULL(cnode);
  425. // ignore nop node itself
  426. if (IsNopNode(cnode)) {
  427. continue;
  428. }
  429. // Replace the input which is nop node
  430. std::vector<AnfNodePtr> new_inputs;
  431. new_inputs.push_back(cnode->input(0));
  432. bool need_update = false;
  433. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  434. auto input = cnode->input(i);
  435. MS_EXCEPTION_IF_NULL(input);
  436. auto cinput = input->cast<CNodePtr>();
  437. if (cinput == nullptr || !IsNopNode(cinput)) {
  438. new_inputs.push_back(input);
  439. continue;
  440. }
  441. if (cinput->inputs().size() == 2) {
  442. new_inputs.push_back(cinput->input(1));
  443. need_update = true;
  444. changed = true;
  445. } else {
  446. new_inputs.push_back(input);
  447. }
  448. }
  449. if (need_update) {
  450. cnode->set_inputs(new_inputs);
  451. }
  452. // push into new execution list
  453. new_nodes.push_back(cnode);
  454. }
  455. graph->set_execution_order(new_nodes);
  456. }
  457. }
  458. size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  459. auto out_list = GetRealNodeUsedList(graph, node);
  460. MS_EXCEPTION_IF_NULL(out_list);
  461. return out_list->size();
  462. }
  463. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  464. const AnfNodePtr &node) {
  465. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  466. MS_EXCEPTION_IF_NULL(graph);
  467. auto manager = graph->manager();
  468. MS_EXCEPTION_IF_NULL(manager);
  469. auto iter = manager->node_users().find(node);
  470. if (iter == manager->node_users().end()) {
  471. MS_LOG(EXCEPTION) << "node has no output in manager";
  472. }
  473. auto output_info_list = iter->second;
  474. for (const auto &output_info : output_info_list) {
  475. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
  476. continue;
  477. }
  478. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  479. output_info.second == kDependAttachNodeIndex) {
  480. continue;
  481. }
  482. output_node_list->push_back(output_info);
  483. }
  484. return output_node_list;
  485. }
  486. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  487. const AnfNodePtr &node,
  488. size_t output_index) {
  489. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  490. MS_EXCEPTION_IF_NULL(graph);
  491. auto manager = graph->manager();
  492. MS_EXCEPTION_IF_NULL(manager);
  493. auto iter = manager->node_users().find(node);
  494. if (iter == manager->node_users().end()) {
  495. MS_LOG(EXCEPTION) << "node has no output in manager";
  496. }
  497. auto output_info_list = iter->second;
  498. for (const auto &output_info : output_info_list) {
  499. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
  500. continue;
  501. }
  502. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  503. output_info.second == kDependAttachNodeIndex) {
  504. continue;
  505. }
  506. size_t used_output_index;
  507. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
  508. used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
  509. } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
  510. used_output_index = output_index;
  511. } else {
  512. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1);
  513. if (kernel_with_index.first.get() != node.get()) {
  514. MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
  515. }
  516. used_output_index = kernel_with_index.second;
  517. }
  518. if (used_output_index == output_index) {
  519. output_node_list->push_back(output_info);
  520. }
  521. }
  522. return output_node_list;
  523. }
  524. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  525. MS_EXCEPTION_IF_NULL(graph);
  526. MS_EXCEPTION_IF_NULL(node);
  527. auto output_node_list = GetRealNodeUsedList(graph, node);
  528. MS_EXCEPTION_IF_NULL(output_node_list);
  529. return output_node_list->size() > 1;
  530. }
  531. bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  532. MS_EXCEPTION_IF_NULL(graph);
  533. MS_EXCEPTION_IF_NULL(node);
  534. auto output_node_list = GetRealNodeUsedList(graph, node);
  535. MS_EXCEPTION_IF_NULL(output_node_list);
  536. if (output_node_list->empty()) {
  537. return true;
  538. }
  539. for (const auto &output : *output_node_list) {
  540. auto out_node = output.first;
  541. auto name = AnfAlgo::GetCNodeName(out_node);
  542. if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
  543. name == prim::kPrimTupleGetItem->name()) {
  544. auto result = IsNotRealUsedByOthers(graph, out_node);
  545. if (!result) {
  546. return result;
  547. }
  548. continue;
  549. }
  550. return false;
  551. }
  552. return true;
  553. }
  554. AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  555. auto idx = NewValueNode(SizeToInt(output_idx));
  556. MS_EXCEPTION_IF_NULL(idx);
  557. auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
  558. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  559. idx->set_abstract(abstract_scalar);
  560. AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  561. MS_EXCEPTION_IF_NULL(tuple_getitem);
  562. tuple_getitem->set_scope(node->scope());
  563. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  564. TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  565. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  566. return tuple_getitem;
  567. }
  568. void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  569. MS_EXCEPTION_IF_NULL(cnode);
  570. std::vector<AnfNodePtr> new_inputs;
  571. std::vector<std::string> new_input_names;
  572. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  573. MS_EXCEPTION_IF_NULL(primitive);
  574. primitive = primitive->Clone();
  575. auto input_names = primitive->GetAttr(kAttrInputNames);
  576. if (input_names == nullptr) {
  577. MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
  578. return;
  579. }
  580. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  581. auto inputs = cnode->inputs();
  582. new_inputs.push_back(inputs[0]);
  583. bool need_update = false;
  584. for (size_t i = 0; i < inputs.size() - 1; ++i) {
  585. auto input_node = inputs[i + 1];
  586. MS_EXCEPTION_IF_NULL(input_node);
  587. if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
  588. auto value_node = input_node->cast<ValueNodePtr>();
  589. MS_EXCEPTION_IF_NULL(value_node);
  590. MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
  591. if (i >= input_names_vec.size()) {
  592. MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
  593. }
  594. primitive->set_attr(input_names_vec[i], value_node->value());
  595. need_update = true;
  596. } else {
  597. new_inputs.push_back(input_node);
  598. if (i < input_names_vec.size()) {
  599. new_input_names.push_back(input_names_vec[i]);
  600. }
  601. }
  602. }
  603. if (need_update) {
  604. // Update cnode's inputs
  605. new_inputs[0] = NewValueNode(primitive);
  606. cnode->set_inputs(new_inputs);
  607. // Update cnode's input_names attr
  608. primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
  609. }
  610. }
  611. bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  612. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  613. auto a_node = utils::cast<AnfNodePtr>(a);
  614. auto b_node = utils::cast<AnfNodePtr>(b);
  615. MS_EXCEPTION_IF_NULL(a_node);
  616. MS_EXCEPTION_IF_NULL(b_node);
  617. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  618. auto a_value_node = a_node->cast<ValueNodePtr>();
  619. MS_EXCEPTION_IF_NULL(a_value_node);
  620. auto a_value = a_value_node->value();
  621. MS_EXCEPTION_IF_NULL(a_value);
  622. auto a_prim = a_value->cast<PrimitivePtr>();
  623. MS_EXCEPTION_IF_NULL(a_prim);
  624. auto b_value_node = b_node->cast<ValueNodePtr>();
  625. MS_EXCEPTION_IF_NULL(b_value_node);
  626. auto b_value = b_value_node->value();
  627. MS_EXCEPTION_IF_NULL(b_value);
  628. auto b_prim = b_value->cast<PrimitivePtr>();
  629. MS_EXCEPTION_IF_NULL(b_prim);
  630. return a_prim->name() == b_prim->name();
  631. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  632. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  633. if (a_value_node_ptr == nullptr) {
  634. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  635. }
  636. auto a_value_ptr = a_value_node_ptr->value();
  637. if (a_value_ptr == nullptr) {
  638. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  639. }
  640. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  641. if (b_value_node_ptr == nullptr) {
  642. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  643. }
  644. auto b_value_ptr = b_value_node_ptr->value();
  645. if (b_value_ptr == nullptr) {
  646. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  647. }
  648. return (*a_value_ptr) == (*b_value_ptr);
  649. }
  650. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  651. }
  652. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  653. MS_LOG(DEBUG) << "check GraphPtr equal";
  654. }
  655. return a == b;
  656. }
  657. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  658. // To matchCNode and Kernel's type
  659. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  660. return true;
  661. }
  662. return a.type() == b.type();
  663. }
  664. namespace {
  665. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  666. if (utils::isa<int>(sexp)) {
  667. return NewValueNode(utils::cast<int>(sexp));
  668. }
  669. if (utils::isa<int64_t>(sexp)) {
  670. return NewValueNode(utils::cast<int64_t>(sexp));
  671. }
  672. if (utils::isa<float>(sexp)) {
  673. return NewValueNode(utils::cast<float>(sexp));
  674. }
  675. if (utils::isa<bool>(sexp)) {
  676. return NewValueNode(utils::cast<bool>(sexp));
  677. }
  678. if (utils::isa<ValuePtr>(sexp)) {
  679. return NewValueNode(utils::cast<ValuePtr>(sexp));
  680. }
  681. return nullptr;
  682. }
  683. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  684. if (utils::isa<FuncGraphPtr>(graph)) {
  685. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  686. }
  687. if (utils::isa<VarPtr>(graph)) {
  688. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  689. }
  690. return nullptr;
  691. }
  692. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  693. if (utils::isa<VarPtr>(graph)) {
  694. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  695. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  696. }
  697. if (utils::isa<FuncGraphPtr>(graph)) {
  698. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  699. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  700. }
  701. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  702. return nullptr;
  703. }
  704. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  705. bool multigraph) {
  706. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  707. std::vector<AnfNodePtr> input_nodes;
  708. const auto &tuple = utils::cast<VectorRef>(sexp);
  709. if (multigraph && utils::isa<VarPtr>(graph)) {
  710. for (auto &x : tuple) {
  711. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  712. input_nodes.push_back(node);
  713. }
  714. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  715. return std::make_shared<CNode>(input_nodes, var_ptr);
  716. }
  717. for (auto &x : tuple) {
  718. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  719. input_nodes.push_back(node);
  720. }
  721. return CreateCNodeWithGraph(input_nodes, graph);
  722. }
  723. } // namespace
  724. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  725. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  726. MS_EXCEPTION_IF_NULL(primitive_vars);
  727. if (utils::isa<VectorRef>(sexp)) {
  728. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  729. }
  730. if (utils::isa<VarPtr>(sexp)) {
  731. auto var_ptr = utils::cast<VarPtr>(sexp);
  732. MS_EXCEPTION_IF_NULL(var_ptr);
  733. if (var_ptr->primitive()) {
  734. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  735. return NewValueNode(var_ptr->primitive());
  736. }
  737. return CreateVarNodeWithSexp(sexp, graph);
  738. }
  739. if (utils::isa<AnfNodePtr>(sexp)) {
  740. return utils::cast<AnfNodePtr>(sexp);
  741. }
  742. auto value_node = CreateValueNodeWithSexp(sexp);
  743. if (value_node == nullptr) {
  744. MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  745. }
  746. return value_node;
  747. }
  748. bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
  749. MS_EXCEPTION_IF_NULL(equiv1);
  750. MS_EXCEPTION_IF_NULL(equiv2);
  751. MS_EXCEPTION_IF_NULL(var_node);
  752. auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
  753. MS_EXCEPTION_IF_NULL(equiv1_node);
  754. auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
  755. MS_EXCEPTION_IF_NULL(equiv2_node);
  756. return *equiv1_node == *equiv2_node;
  757. }
  758. AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
  759. MS_EXCEPTION_IF_NULL(equiv);
  760. MS_EXCEPTION_IF_NULL(var_node);
  761. auto iter = (*equiv).find(var_node);
  762. if (iter == (*equiv).end()) {
  763. MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
  764. return nullptr;
  765. }
  766. auto res = utils::cast<AnfNodePtr>(iter->second);
  767. if (res == nullptr) {
  768. MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
  769. }
  770. return res;
  771. }
  772. bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
  773. MS_EXCEPTION_IF_NULL(n1);
  774. MS_EXCEPTION_IF_NULL(n2);
  775. auto n1_cnode = n1->cast<CNodePtr>();
  776. auto n2_cnode = n2->cast<CNodePtr>();
  777. MS_EXCEPTION_IF_NULL(n1_cnode);
  778. MS_EXCEPTION_IF_NULL(n2_cnode);
  779. auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  780. MS_EXCEPTION_IF_NULL(index_input1);
  781. auto value_node1 = index_input1->cast<ValueNodePtr>();
  782. MS_EXCEPTION_IF_NULL(value_node1);
  783. auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  784. MS_EXCEPTION_IF_NULL(index_input2);
  785. auto value_node2 = index_input2->cast<ValueNodePtr>();
  786. MS_EXCEPTION_IF_NULL(value_node2);
  787. return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
  788. }
  789. bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  790. MS_EXCEPTION_IF_NULL(node);
  791. if (!node->isa<CNode>()) {
  792. MS_LOG(INFO) << "node is not a cnode";
  793. return false;
  794. }
  795. auto cnode = node->cast<CNodePtr>();
  796. MS_EXCEPTION_IF_NULL(cnode);
  797. return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
  798. }
  799. bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
  800. MS_EXCEPTION_IF_NULL(node);
  801. TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
  802. if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
  803. return true;
  804. }
  805. MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
  806. return false;
  807. }
  808. ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  809. MS_EXCEPTION_IF_NULL(value_node);
  810. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  811. new_value_node->set_abstract(value_node->abstract());
  812. // create kernel_info fo new value node
  813. auto kernel_info = std::make_shared<device::KernelInfo>();
  814. new_value_node->set_kernel_info(kernel_info);
  815. // create kernel_build_info for new value node
  816. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  817. // set the format of value_node to DEFAULT_FORMAT
  818. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  819. // set value node initial device data type = infer data type
  820. std::vector<TypeId> types;
  821. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
  822. types.push_back(kTypeUnknown);
  823. }
  824. kernel_build_info_builder->SetOutputsDeviceType(types);
  825. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  826. return new_value_node;
  827. }
  828. void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  829. MS_EXCEPTION_IF_NULL(old_node);
  830. MS_EXCEPTION_IF_NULL(graph);
  831. auto manager = graph->manager();
  832. MS_EXCEPTION_IF_NULL(manager);
  833. // find BatchNorm's output which is a Depend or ControlDepend
  834. for (const auto &node_index : manager->node_users()[old_node]) {
  835. AnfNodePtr output = node_index.first;
  836. size_t index = IntToSize(node_index.second);
  837. MS_EXCEPTION_IF_NULL(output);
  838. if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
  839. auto control_depend = output->cast<CNodePtr>();
  840. MS_EXCEPTION_IF_NULL(control_depend);
  841. control_depend->set_input(index, new_node);
  842. } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
  843. auto depend = output->cast<CNodePtr>();
  844. MS_EXCEPTION_IF_NULL(depend);
  845. depend->set_input(index, new_node);
  846. }
  847. }
  848. }
  849. } // namespace opt
  850. } // namespace mindspore