You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cc 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/common/helper.h"
  17. #include <string>
  18. #include <utility>
  19. #include <unordered_set>
  20. #include <algorithm>
  21. #include <map>
  22. #include <set>
  23. #include <deque>
  24. #include "utils/utils.h"
  25. #include "base/base_ref.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "base/core_ops.h"
  28. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  29. #include "frontend/operator/ops.h"
  30. #include "utils/ms_utils.h"
  31. #include "runtime/device/kernel_info.h"
  32. #include "utils/ms_context.h"
  33. namespace mindspore {
  34. namespace opt {
  35. constexpr size_t kType32Len = 4;
  36. std::vector<int> Convert2Int(const std::vector<size_t> &v) {
  37. std::vector<int> result;
  38. (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
  39. return result;
  40. }
  41. bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
  42. MS_EXCEPTION_IF_NULL(node);
  43. std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
  44. std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
  45. for (auto &nd : node_list) {
  46. MS_EXCEPTION_IF_NULL(nd);
  47. if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
  48. auto control_depend = nd->cast<CNodePtr>();
  49. auto prior_node = control_depend->input(kControlDependPriorIndex);
  50. auto behind_node = control_depend->input(kControlDependBehindIndex);
  51. auto it = control_depend_map.find(behind_node);
  52. if (it == control_depend_map.end()) {
  53. control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
  54. } else {
  55. it->second.insert(prior_node);
  56. }
  57. }
  58. }
  59. FuncGraphManagerPtr manager = graph.manager();
  60. MS_EXCEPTION_IF_NULL(manager);
  61. std::unordered_set<AnfNodePtr> seen_node;
  62. std::deque<AnfNodePtr> todo{node};
  63. while (!todo.empty()) {
  64. AnfNodePtr nd = todo.front();
  65. todo.pop_front();
  66. if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
  67. continue;
  68. }
  69. (void)seen_node.insert(nd);
  70. if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
  71. return true;
  72. }
  73. if (nd->isa<CNode>()) {
  74. auto cnode = nd->cast<CNodePtr>();
  75. MS_EXCEPTION_IF_NULL(cnode);
  76. auto inputs = cnode->inputs();
  77. (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
  78. }
  79. auto it = control_depend_map.find(nd);
  80. if (it != control_depend_map.end()) {
  81. (void)todo.insert(todo.end(), it->second.begin(), it->second.end());
  82. }
  83. }
  84. return false;
  85. }
  86. bool UnVisited(const BaseRef &n) {
  87. if (utils::isa<AnfNodePtr>(n)) {
  88. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  89. MS_EXCEPTION_IF_NULL(in);
  90. if (IsValueNode<Primitive>(in)) {
  91. auto value_node = in->cast<ValueNodePtr>();
  92. MS_EXCEPTION_IF_NULL(value_node);
  93. auto value = value_node->value();
  94. MS_EXCEPTION_IF_NULL(value);
  95. auto prim_py = value->cast<PrimitivePtr>();
  96. MS_EXCEPTION_IF_NULL(prim_py);
  97. return !prim_py->HasAttr(kAttrVisited);
  98. } else if (IsValueNode<FuncGraph>(in)) {
  99. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  100. MS_EXCEPTION_IF_NULL(func_graph);
  101. return !func_graph->has_flag(kAttrVisited);
  102. }
  103. return false;
  104. }
  105. return false;
  106. }
  107. bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) {
  108. MS_EXCEPTION_IF_NULL(node);
  109. if (!node->isa<CNode>()) {
  110. MS_LOG(ERROR) << "The node is expected to be a cnode";
  111. return false;
  112. }
  113. *cnode = node->cast<CNodePtr>();
  114. if (*cnode == nullptr) {
  115. return false;
  116. }
  117. if ((*cnode)->inputs().size() < IntToSize(input_size)) {
  118. auto op_name = AnfAlgo::GetCNodeName(*cnode);
  119. MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  120. return false;
  121. }
  122. return true;
  123. }
  124. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) {
  125. MS_EXCEPTION_IF_NULL(node);
  126. if (!node->isa<CNode>()) {
  127. MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
  128. }
  129. auto cnode = node->cast<CNodePtr>();
  130. MS_EXCEPTION_IF_NULL(cnode);
  131. if (cnode->inputs().size() != IntToSize(input_size)) {
  132. auto op_name = AnfAlgo::GetCNodeName(cnode);
  133. MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs.";
  134. }
  135. return cnode;
  136. }
  137. void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) {
  138. MS_EXCEPTION_IF_NULL(cnode);
  139. if (cnode->inputs().size() != input_size) {
  140. MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size;
  141. }
  142. }
  143. bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) {
  144. MS_EXCEPTION_IF_NULL(node_x);
  145. MS_EXCEPTION_IF_NULL(node_y);
  146. return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) &&
  147. AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0));
  148. }
  149. const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  150. MS_EXCEPTION_IF_NULL(func_graph);
  151. auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
  152. MS_EXCEPTION_IF_NULL(transop_cnode);
  153. auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
  154. auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
  155. MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
  156. MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1));
  157. auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1);
  158. MS_EXCEPTION_IF_NULL(transed_node);
  159. std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
  160. depend_cnode->input(kDependInputNum - 1)};
  161. AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
  162. MS_EXCEPTION_IF_NULL(replace_depend);
  163. auto transed_abstract = transed_node->abstract();
  164. replace_depend->set_abstract(transed_abstract);
  165. return replace_depend;
  166. }
  167. bool Visited(const BaseRef &n) {
  168. if (utils::isa<AnfNodePtr>(n)) {
  169. AnfNodePtr in = utils::cast<AnfNodePtr>(n);
  170. MS_EXCEPTION_IF_NULL(in);
  171. if (IsValueNode<Primitive>(in)) {
  172. auto value_node = in->cast<ValueNodePtr>();
  173. MS_EXCEPTION_IF_NULL(value_node);
  174. auto value = value_node->value();
  175. MS_EXCEPTION_IF_NULL(value);
  176. auto prim_py = value->cast<PrimitivePtr>();
  177. MS_EXCEPTION_IF_NULL(prim_py);
  178. return prim_py->HasAttr(kAttrVisited);
  179. } else if (IsValueNode<FuncGraph>(in)) {
  180. auto func_graph = GetValueNode<FuncGraphPtr>(in);
  181. MS_EXCEPTION_IF_NULL(func_graph);
  182. return func_graph->has_flag(kAttrVisited);
  183. }
  184. return false;
  185. }
  186. return false;
  187. }
  188. void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode,
  189. std::vector<AnfNodePtr> *conv_bn1_outputs) {
  190. auto prim = std::make_shared<Primitive>(kConvBN1OpName);
  191. std::vector<AnfNodePtr> conv_bn1_inputs = {NewValueNode(prim)};
  192. MS_EXCEPTION_IF_NULL(conv_cnode);
  193. // All the inputs of conv_bn1 are from the inputs of conv
  194. for (size_t i = 1; i < conv_cnode->inputs().size(); i++) {
  195. conv_bn1_inputs.push_back(conv_cnode->input(i));
  196. }
  197. MS_EXCEPTION_IF_NULL(func_graph);
  198. CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs);
  199. MS_EXCEPTION_IF_NULL(conv_bn1_cnode);
  200. auto kernel_info = std::make_shared<device::KernelInfo>();
  201. conv_bn1_cnode->set_kernel_info(kernel_info);
  202. // Set attr for conv_bn1
  203. AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode);
  204. // Set abstract of conv_bn1
  205. MS_EXCEPTION_IF_NULL(bn_cnode);
  206. auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_cnode->abstract());
  207. MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  208. AbstractBasePtrList conv_bn1_abstract_list;
  209. conv_bn1_abstract_list.push_back(conv_cnode->abstract());
  210. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(
  211. kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1)));
  212. conv_bn1_abstract_list.push_back(abstract_tensor);
  213. conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]);
  214. auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(conv_bn1_abstract_list);
  215. conv_bn1_cnode->set_abstract(abstract_tuple);
  216. CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs);
  217. }
  218. void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &fused_bn1_outputs,
  219. const CNodePtr &bn_node, std::vector<AnfNodePtr> *fused_bn2_outputs) {
  220. MS_EXCEPTION_IF_NULL(graph);
  221. MS_EXCEPTION_IF_NULL(bn_node);
  222. MS_EXCEPTION_IF_NULL(fused_bn2_outputs);
  223. if (bn_node->inputs().size() != kBnInputNum) {
  224. MS_LOG(EXCEPTION) << "BN node has wrong input size";
  225. }
  226. if (fused_bn1_outputs.size() != kBN1OutputNum) {
  227. MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  228. }
  229. // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn
  230. std::vector<AnfNodePtr> fused_bn2_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN2OpName))};
  231. fused_bn2_inputs.push_back(fused_bn1_outputs[0]);
  232. fused_bn2_inputs.push_back(fused_bn1_outputs[1]);
  233. fused_bn2_inputs.push_back(bn_node->input(4));
  234. fused_bn2_inputs.push_back(bn_node->input(5));
  235. auto fused_bn2 = graph->NewCNode(fused_bn2_inputs);
  236. MS_EXCEPTION_IF_NULL(fused_bn2);
  237. auto kernel_info = std::make_shared<device::KernelInfo>();
  238. fused_bn2->set_kernel_info(kernel_info);
  239. auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1),
  240. AnfAlgo::GetOutputInferDataType(bn_node, 2)};
  241. auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1),
  242. AnfAlgo::GetOutputInferShape(bn_node, 2)};
  243. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get());
  244. fused_bn2->set_scope(bn_node->scope());
  245. AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2);
  246. CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs);
  247. }
  248. void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
  249. const std::vector<AnfNodePtr> &fused_bn1_outputs,
  250. const std::vector<AnfNodePtr> &fused_bn2_outputs, const CNodePtr &bn_node,
  251. std::vector<AnfNodePtr> *fused_bn3_outputs) {
  252. MS_EXCEPTION_IF_NULL(graph);
  253. MS_EXCEPTION_IF_NULL(data_input);
  254. MS_EXCEPTION_IF_NULL(bn_node);
  255. MS_EXCEPTION_IF_NULL(fused_bn3_outputs);
  256. if (bn_node->inputs().size() != kBnInputNum) {
  257. MS_LOG(EXCEPTION) << "BN node has wrong input size";
  258. }
  259. if (fused_bn1_outputs.size() != kBN1OutputNum) {
  260. MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
  261. }
  262. if (fused_bn2_outputs.size() != kBN2OutputNum) {
  263. MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size";
  264. }
  265. // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn
  266. std::vector<AnfNodePtr> fused_bn3_inputs = {NewValueNode(std::make_shared<Primitive>(kFusedBN3OpName))};
  267. fused_bn3_inputs.push_back(data_input);
  268. fused_bn3_inputs.push_back(fused_bn1_outputs[0]);
  269. fused_bn3_inputs.push_back(fused_bn2_outputs[0]);
  270. fused_bn3_inputs.push_back(bn_node->input(2));
  271. fused_bn3_inputs.push_back(bn_node->input(3));
  272. auto fused_bn3 = graph->NewCNode(fused_bn3_inputs);
  273. MS_EXCEPTION_IF_NULL(fused_bn3);
  274. auto kernel_info = std::make_shared<device::KernelInfo>();
  275. fused_bn3->set_kernel_info(kernel_info);
  276. auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)};
  277. auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)};
  278. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get());
  279. fused_bn3->set_scope(bn_node->scope());
  280. AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3);
  281. (*fused_bn3_outputs).push_back(fused_bn3);
  282. }
  283. void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
  284. std::vector<AnfNodePtr> *outputs) {
  285. MS_EXCEPTION_IF_NULL(func_graph);
  286. MS_EXCEPTION_IF_NULL(node);
  287. MS_EXCEPTION_IF_NULL(outputs);
  288. for (size_t i = 0; i < output_num; i++) {
  289. int temp = SizeToInt(i);
  290. auto idx = NewValueNode(temp);
  291. MS_EXCEPTION_IF_NULL(idx);
  292. auto imm = std::make_shared<Int32Imm>(temp);
  293. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  294. idx->set_abstract(abstract_scalar);
  295. auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  296. MS_EXCEPTION_IF_NULL(tuple_getitem);
  297. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)},
  298. {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get());
  299. (*outputs).push_back(tuple_getitem);
  300. }
  301. }
  302. template <typename T>
  303. tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
  304. size_t data_length) {
  305. MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  306. MS_EXCEPTION_IF_NULL(type_ptr);
  307. std::vector<T> values;
  308. for (const auto &v : value_tuple_ptr->value()) {
  309. MS_EXCEPTION_IF_NULL(v);
  310. if (v->isa<Scalar>()) {
  311. ScalarPtr scalar = v->cast<ScalarPtr>();
  312. values.push_back(GetValue<T>(scalar));
  313. } else {
  314. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  315. return nullptr;
  316. }
  317. }
  318. std::vector<int> tensor_shape = {SizeToInt(values.size())};
  319. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
  320. MS_EXCEPTION_IF_NULL(tensor);
  321. tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
  322. tensor->set_device_info(device_info);
  323. auto data_ptr = tensor->data_c();
  324. MS_EXCEPTION_IF_NULL(data_ptr);
  325. auto elem_num = values.size() * data_length;
  326. auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
  327. if (ret_code != 0) {
  328. MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  329. }
  330. return tensor;
  331. }
  332. tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  333. MS_EXCEPTION_IF_NULL(value_tuple);
  334. tensor::TensorPtr tensor = nullptr;
  335. if (value_tuple->value().empty()) {
  336. MS_LOG(WARNING) << "The value tuple is empty.";
  337. return nullptr;
  338. }
  339. ValuePtr v = *(value_tuple->value().begin());
  340. MS_EXCEPTION_IF_NULL(v);
  341. // Currently we only deal with the scalar tuple
  342. if (!v->isa<Scalar>()) {
  343. MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
  344. return nullptr;
  345. }
  346. ScalarPtr scalar = v->cast<ScalarPtr>();
  347. MS_EXCEPTION_IF_NULL(scalar);
  348. if (scalar->isa<Int32Imm>()) {
  349. tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
  350. } else if (scalar->isa<Int64Imm>()) {
  351. tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
  352. } else if (scalar->isa<FloatImm>()) {
  353. tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
  354. } else {
  355. auto type = scalar->type();
  356. auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
  357. MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
  358. return nullptr;
  359. }
  360. return tensor;
  361. }
  362. bool IsNopNode(const AnfNodePtr &node) {
  363. auto context_ptr = MsContext::GetInstance();
  364. MS_EXCEPTION_IF_NULL(context_ptr);
  365. if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
  366. context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
  367. return false;
  368. }
  369. static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
  370. prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
  371. kFlattenGradOpName};
  372. if (node == nullptr || !node->isa<CNode>()) {
  373. return false;
  374. }
  375. CNodePtr cnode = node->cast<CNodePtr>();
  376. MS_EXCEPTION_IF_NULL(cnode);
  377. if (cnode->inputs().empty()) {
  378. return false;
  379. }
  380. auto input0 = cnode->input(0);
  381. MS_EXCEPTION_IF_NULL(input0);
  382. if (!input0->isa<ValueNode>()) {
  383. return false;
  384. }
  385. bool is_nop_node = false;
  386. if (AnfAlgo::HasNodeAttr("nop_op", cnode)) {
  387. is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, "nop_op");
  388. }
  389. if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
  390. return false;
  391. }
  392. return true;
  393. }
  394. bool IsAllNopNode(const session::KernelGraph *const graph) {
  395. MS_EXCEPTION_IF_NULL(graph);
  396. auto execution_order = graph->execution_order();
  397. for (auto &cnode : execution_order) {
  398. MS_EXCEPTION_IF_NULL(cnode);
  399. if (!IsNopNode(cnode)) {
  400. return false;
  401. }
  402. }
  403. return true;
  404. }
  405. void HideNopNode(session::KernelGraph *const graph) {
  406. MS_EXCEPTION_IF_NULL(graph);
  407. if (IsAllNopNode(graph) == true) {
  408. return;
  409. }
  410. auto execution_order = graph->execution_order();
  411. MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
  412. std::vector<CNodePtr> new_nodes;
  413. for (auto &cnode : execution_order) {
  414. MS_EXCEPTION_IF_NULL(cnode);
  415. if (!IsNopNode(cnode)) {
  416. new_nodes.push_back(cnode);
  417. }
  418. }
  419. graph->set_execution_order(new_nodes);
  420. MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size();
  421. }
  422. void RemoveNopNode(session::KernelGraph *const graph) {
  423. MS_EXCEPTION_IF_NULL(graph);
  424. if (IsAllNopNode(graph) == true) {
  425. return;
  426. }
  427. bool changed = true;
  428. while (changed) {
  429. changed = false;
  430. std::vector<CNodePtr> new_nodes;
  431. for (auto &cnode : graph->execution_order()) {
  432. MS_EXCEPTION_IF_NULL(cnode);
  433. // ignore nop node itself
  434. if (IsNopNode(cnode)) {
  435. continue;
  436. }
  437. // Replace the input which is nop node
  438. std::vector<AnfNodePtr> new_inputs;
  439. new_inputs.push_back(cnode->input(0));
  440. bool need_update = false;
  441. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  442. auto input = cnode->input(i);
  443. MS_EXCEPTION_IF_NULL(input);
  444. auto cinput = input->cast<CNodePtr>();
  445. if (cinput == nullptr || !IsNopNode(cinput)) {
  446. new_inputs.push_back(input);
  447. continue;
  448. }
  449. if (cinput->inputs().size() == 2) {
  450. new_inputs.push_back(cinput->input(1));
  451. need_update = true;
  452. changed = true;
  453. } else {
  454. new_inputs.push_back(input);
  455. }
  456. }
  457. if (need_update) {
  458. cnode->set_inputs(new_inputs);
  459. }
  460. // push into new execution list
  461. new_nodes.push_back(cnode);
  462. }
  463. graph->set_execution_order(new_nodes);
  464. }
  465. }
  466. size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  467. auto out_list = GetRealNodeUsedList(graph, node);
  468. MS_EXCEPTION_IF_NULL(out_list);
  469. return out_list->size();
  470. }
  471. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  472. const AnfNodePtr &node) {
  473. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  474. MS_EXCEPTION_IF_NULL(graph);
  475. auto manager = graph->manager();
  476. MS_EXCEPTION_IF_NULL(manager);
  477. auto iter = manager->node_users().find(node);
  478. if (iter == manager->node_users().end()) {
  479. MS_LOG(EXCEPTION) << "node has no output in manager";
  480. }
  481. auto output_info_list = iter->second;
  482. for (const auto &output_info : output_info_list) {
  483. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
  484. continue;
  485. }
  486. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  487. output_info.second == kDependAttachNodeIndex) {
  488. continue;
  489. }
  490. output_node_list->push_back(output_info);
  491. }
  492. return output_node_list;
  493. }
  494. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  495. const AnfNodePtr &node,
  496. size_t output_index) {
  497. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  498. MS_EXCEPTION_IF_NULL(graph);
  499. auto manager = graph->manager();
  500. MS_EXCEPTION_IF_NULL(manager);
  501. auto iter = manager->node_users().find(node);
  502. if (iter == manager->node_users().end()) {
  503. MS_LOG(EXCEPTION) << "node has no output in manager";
  504. }
  505. auto output_info_list = iter->second;
  506. for (const auto &output_info : output_info_list) {
  507. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
  508. continue;
  509. }
  510. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
  511. output_info.second == kDependAttachNodeIndex) {
  512. continue;
  513. }
  514. size_t used_output_index;
  515. if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
  516. used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
  517. } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
  518. used_output_index = output_index;
  519. } else {
  520. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1);
  521. if (kernel_with_index.first.get() != node.get()) {
  522. MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
  523. }
  524. used_output_index = kernel_with_index.second;
  525. }
  526. if (used_output_index == output_index) {
  527. output_node_list->push_back(output_info);
  528. }
  529. }
  530. return output_node_list;
  531. }
  532. bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  533. MS_EXCEPTION_IF_NULL(graph);
  534. MS_EXCEPTION_IF_NULL(node);
  535. auto output_node_list = GetRealNodeUsedList(graph, node);
  536. MS_EXCEPTION_IF_NULL(output_node_list);
  537. return output_node_list->size() > 1;
  538. }
  539. bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  540. MS_EXCEPTION_IF_NULL(graph);
  541. MS_EXCEPTION_IF_NULL(node);
  542. auto output_node_list = GetRealNodeUsedList(graph, node);
  543. MS_EXCEPTION_IF_NULL(output_node_list);
  544. if (output_node_list->empty()) {
  545. return true;
  546. }
  547. for (const auto &output : *output_node_list) {
  548. auto out_node = output.first;
  549. auto name = AnfAlgo::GetCNodeName(out_node);
  550. if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
  551. name == prim::kPrimTupleGetItem->name()) {
  552. auto result = IsNotRealUsedByOthers(graph, out_node);
  553. if (!result) {
  554. return result;
  555. }
  556. continue;
  557. }
  558. return false;
  559. }
  560. return true;
  561. }
  562. AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
  563. auto idx = NewValueNode(SizeToInt(output_idx));
  564. MS_EXCEPTION_IF_NULL(idx);
  565. auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
  566. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  567. idx->set_abstract(abstract_scalar);
  568. AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  569. MS_EXCEPTION_IF_NULL(tuple_getitem);
  570. tuple_getitem->set_scope(node->scope());
  571. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  572. TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  573. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  574. return tuple_getitem;
  575. }
  576. void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  577. MS_EXCEPTION_IF_NULL(cnode);
  578. std::vector<AnfNodePtr> new_inputs;
  579. std::vector<std::string> new_input_names;
  580. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  581. MS_EXCEPTION_IF_NULL(primitive);
  582. primitive = primitive->Clone();
  583. auto input_names = primitive->GetAttr(kAttrInputNames);
  584. if (input_names == nullptr) {
  585. MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
  586. return;
  587. }
  588. auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  589. auto inputs = cnode->inputs();
  590. new_inputs.push_back(inputs[0]);
  591. bool need_update = false;
  592. for (size_t i = 0; i < inputs.size() - 1; ++i) {
  593. auto input_node = inputs[i + 1];
  594. MS_EXCEPTION_IF_NULL(input_node);
  595. if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
  596. auto value_node = input_node->cast<ValueNodePtr>();
  597. MS_EXCEPTION_IF_NULL(value_node);
  598. MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
  599. if (i >= input_names_vec.size()) {
  600. MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
  601. }
  602. primitive->set_attr(input_names_vec[i], value_node->value());
  603. need_update = true;
  604. } else {
  605. new_inputs.push_back(input_node);
  606. if (i < input_names_vec.size()) {
  607. new_input_names.push_back(input_names_vec[i]);
  608. }
  609. }
  610. }
  611. if (need_update) {
  612. // Update cnode's inputs
  613. new_inputs[0] = NewValueNode(primitive);
  614. cnode->set_inputs(new_inputs);
  615. // Update cnode's input_names attr
  616. primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
  617. }
  618. }
  619. bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  620. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  621. auto a_node = utils::cast<AnfNodePtr>(a);
  622. auto b_node = utils::cast<AnfNodePtr>(b);
  623. MS_EXCEPTION_IF_NULL(a_node);
  624. MS_EXCEPTION_IF_NULL(b_node);
  625. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  626. auto a_value_node = a_node->cast<ValueNodePtr>();
  627. MS_EXCEPTION_IF_NULL(a_value_node);
  628. auto a_value = a_value_node->value();
  629. MS_EXCEPTION_IF_NULL(a_value);
  630. auto a_prim = a_value->cast<PrimitivePtr>();
  631. MS_EXCEPTION_IF_NULL(a_prim);
  632. auto b_value_node = b_node->cast<ValueNodePtr>();
  633. MS_EXCEPTION_IF_NULL(b_value_node);
  634. auto b_value = b_value_node->value();
  635. MS_EXCEPTION_IF_NULL(b_value);
  636. auto b_prim = b_value->cast<PrimitivePtr>();
  637. MS_EXCEPTION_IF_NULL(b_prim);
  638. return a_prim->name() == b_prim->name();
  639. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  640. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  641. if (a_value_node_ptr == nullptr) {
  642. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  643. }
  644. auto a_value_ptr = a_value_node_ptr->value();
  645. if (a_value_ptr == nullptr) {
  646. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  647. }
  648. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  649. if (b_value_node_ptr == nullptr) {
  650. MS_LOG(EXCEPTION) << "cast value node ptr fail";
  651. }
  652. auto b_value_ptr = b_value_node_ptr->value();
  653. if (b_value_ptr == nullptr) {
  654. MS_LOG(EXCEPTION) << "value ptr is nullptr";
  655. }
  656. return (*a_value_ptr) == (*b_value_ptr);
  657. }
  658. MS_LOG(DEBUG) << "check AnfNodePtr equal";
  659. }
  660. if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
  661. MS_LOG(DEBUG) << "check GraphPtr equal";
  662. }
  663. return a == b;
  664. }
  665. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  666. // To matchCNode and Kernel's type
  667. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  668. return true;
  669. }
  670. return a.type() == b.type();
  671. }
  672. namespace {
  673. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  674. if (utils::isa<int>(sexp)) {
  675. return NewValueNode(utils::cast<int>(sexp));
  676. }
  677. if (utils::isa<int64_t>(sexp)) {
  678. return NewValueNode(utils::cast<int64_t>(sexp));
  679. }
  680. if (utils::isa<float>(sexp)) {
  681. return NewValueNode(utils::cast<float>(sexp));
  682. }
  683. if (utils::isa<bool>(sexp)) {
  684. return NewValueNode(utils::cast<bool>(sexp));
  685. }
  686. if (utils::isa<ValuePtr>(sexp)) {
  687. return NewValueNode(utils::cast<ValuePtr>(sexp));
  688. }
  689. return nullptr;
  690. }
  691. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  692. if (utils::isa<FuncGraphPtr>(graph)) {
  693. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  694. }
  695. if (utils::isa<VarPtr>(graph)) {
  696. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  697. }
  698. return nullptr;
  699. }
  700. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  701. if (utils::isa<VarPtr>(graph)) {
  702. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  703. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  704. }
  705. if (utils::isa<FuncGraphPtr>(graph)) {
  706. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  707. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  708. }
  709. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  710. return nullptr;
  711. }
  712. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  713. bool multigraph) {
  714. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  715. std::vector<AnfNodePtr> input_nodes;
  716. const auto &tuple = utils::cast<VectorRef>(sexp);
  717. if (multigraph && utils::isa<VarPtr>(graph)) {
  718. for (auto &x : tuple) {
  719. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  720. input_nodes.push_back(node);
  721. }
  722. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  723. return std::make_shared<CNode>(input_nodes, var_ptr);
  724. }
  725. for (auto &x : tuple) {
  726. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  727. input_nodes.push_back(node);
  728. }
  729. return CreateCNodeWithGraph(input_nodes, graph);
  730. }
  731. } // namespace
  732. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  733. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  734. MS_EXCEPTION_IF_NULL(primitive_vars);
  735. if (utils::isa<VectorRef>(sexp)) {
  736. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  737. }
  738. if (utils::isa<VarPtr>(sexp)) {
  739. auto var_ptr = utils::cast<VarPtr>(sexp);
  740. MS_EXCEPTION_IF_NULL(var_ptr);
  741. if (var_ptr->primitive()) {
  742. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  743. return NewValueNode(var_ptr->primitive());
  744. }
  745. return CreateVarNodeWithSexp(sexp, graph);
  746. }
  747. if (utils::isa<AnfNodePtr>(sexp)) {
  748. return utils::cast<AnfNodePtr>(sexp);
  749. }
  750. auto value_node = CreateValueNodeWithSexp(sexp);
  751. if (value_node == nullptr) {
  752. MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
  753. }
  754. return value_node;
  755. }
  756. bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
  757. MS_EXCEPTION_IF_NULL(equiv1);
  758. MS_EXCEPTION_IF_NULL(equiv2);
  759. MS_EXCEPTION_IF_NULL(var_node);
  760. auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
  761. MS_EXCEPTION_IF_NULL(equiv1_node);
  762. auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
  763. MS_EXCEPTION_IF_NULL(equiv2_node);
  764. return *equiv1_node == *equiv2_node;
  765. }
  766. AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
  767. MS_EXCEPTION_IF_NULL(equiv);
  768. MS_EXCEPTION_IF_NULL(var_node);
  769. auto iter = (*equiv).find(var_node);
  770. if (iter == (*equiv).end()) {
  771. MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
  772. return nullptr;
  773. }
  774. auto res = utils::cast<AnfNodePtr>(iter->second);
  775. if (res == nullptr) {
  776. MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
  777. }
  778. return res;
  779. }
  780. bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
  781. MS_EXCEPTION_IF_NULL(n1);
  782. MS_EXCEPTION_IF_NULL(n2);
  783. auto n1_cnode = n1->cast<CNodePtr>();
  784. auto n2_cnode = n2->cast<CNodePtr>();
  785. MS_EXCEPTION_IF_NULL(n1_cnode);
  786. MS_EXCEPTION_IF_NULL(n2_cnode);
  787. auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  788. MS_EXCEPTION_IF_NULL(index_input1);
  789. auto value_node1 = index_input1->cast<ValueNodePtr>();
  790. MS_EXCEPTION_IF_NULL(value_node1);
  791. auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
  792. MS_EXCEPTION_IF_NULL(index_input2);
  793. auto value_node2 = index_input2->cast<ValueNodePtr>();
  794. MS_EXCEPTION_IF_NULL(value_node2);
  795. return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
  796. }
  797. bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
  798. MS_EXCEPTION_IF_NULL(node);
  799. if (!node->isa<CNode>()) {
  800. MS_LOG(INFO) << "node is not a cnode";
  801. return false;
  802. }
  803. auto cnode = node->cast<CNodePtr>();
  804. MS_EXCEPTION_IF_NULL(cnode);
  805. return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
  806. }
  807. bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
  808. MS_EXCEPTION_IF_NULL(node);
  809. TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
  810. if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
  811. return true;
  812. }
  813. MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
  814. return false;
  815. }
  816. ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
  817. MS_EXCEPTION_IF_NULL(value_node);
  818. ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  819. new_value_node->set_abstract(value_node->abstract());
  820. // create kernel_info fo new value node
  821. auto kernel_info = std::make_shared<device::KernelInfo>();
  822. new_value_node->set_kernel_info(kernel_info);
  823. // create kernel_build_info for new value node
  824. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  825. // set the format of value_node to DEFAULT_FORMAT
  826. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  827. // set value node initial device data type = infer data type
  828. std::vector<TypeId> types;
  829. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
  830. types.push_back(kTypeUnknown);
  831. }
  832. kernel_build_info_builder->SetOutputsDeviceType(types);
  833. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  834. return new_value_node;
  835. }
  836. void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
  837. MS_EXCEPTION_IF_NULL(old_node);
  838. MS_EXCEPTION_IF_NULL(graph);
  839. auto manager = graph->manager();
  840. MS_EXCEPTION_IF_NULL(manager);
  841. // find BatchNorm's output which is a Depend or ControlDepend
  842. for (const auto &node_index : manager->node_users()[old_node]) {
  843. AnfNodePtr output = node_index.first;
  844. size_t index = IntToSize(node_index.second);
  845. MS_EXCEPTION_IF_NULL(output);
  846. if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
  847. auto control_depend = output->cast<CNodePtr>();
  848. MS_EXCEPTION_IF_NULL(control_depend);
  849. control_depend->set_input(index, new_node);
  850. } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
  851. auto depend = output->cast<CNodePtr>();
  852. MS_EXCEPTION_IF_NULL(depend);
  853. depend->set_input(index, new_node);
  854. }
  855. }
  856. }
  857. } // namespace opt
  858. } // namespace mindspore