You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.cc 36 kB

6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/anf_runtime_algorithm.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "operator/ops.h"
  24. #include "utils/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "device/device_address.h"
  27. #include "pre_activate/common/helper.h"
  28. #include "kernel/kernel.h"
  29. #include "kernel/kernel_build_info.h"
  30. #include "common/utils.h"
  31. #include "common/trans.h"
  32. namespace mindspore {
  33. namespace session {
  34. using abstract::AbstractTensor;
  35. using abstract::AbstractTuple;
  36. using device::KernelInfo;
  37. using device::ascend::AscendDeviceAddress;
  38. using kernel::KernelBuildInfoPtr;
  39. using kernel::KernelMod;
  40. using kernel::KernelModPtr;
  41. namespace {
  42. std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  43. MS_EXCEPTION_IF_NULL(shape);
  44. std::vector<size_t> shape_size_t;
  45. std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  46. return shape_size_t;
  47. }
  48. } // namespace
// Resolve `anf_node` to the real kernel node (plus output index) that produces the
// value consumed at `index`, skipping graph-structural wrappers: MakeTuple,
// TupleGetItem, Depend and ControlDepend. Value nodes and parameters resolve to
// themselves at output 0; any other non-CNode input raises an exception.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    // A value node has exactly one output.
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    // So does a parameter.
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // The index-th output of MakeTuple is its (index + 1)-th input
      // (input 0 is the primitive itself).
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Second real input of TupleGetItem is a constant holding the selected output index.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      // Recurse into the producer with the selected index.
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      // Depend/ControlDepend only add ordering edges; follow the real data input.
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      // A real kernel node: stop here and report the requested index.
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Same resolution as VisitKernel, with two extra behaviors:
//  - if `anf_node` itself matches any primitive in `return_types`, stop immediately
//    and return it with the given `index`;
//  - if `visit_nop_node` is true, single-input "nop" nodes (opt::IsNopNode) are
//    transparently skipped as well.
// Note: unlike VisitKernel, MakeTuple is NOT unwrapped here; callers that want
// MakeTuple unwrapping pass it in `return_types` (see GetAllOutput).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Caller-requested early stop on specific primitive types.
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Constant output index selected by TupleGetItem.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      // Ordering-only edges; follow the real input.
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A nop node must have exactly one real input (primitive + 1 == 2).
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Collect every real output node reachable from `node`, flattening nested
// MakeTuple structures recursively. `return_types` lets callers stop the
// descent early at specific primitives (MakeTuple is added internally so the
// visit surfaces tuples for this function to expand).
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // If a MakeTuple is visited, VisitKernelWithReturnType must return it back so
  // we can expand its elements here.
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    // Inputs start at 1; input 0 is the MakeTuple primitive itself.
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      // Recurse: tuple elements may themselves be tuples.
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  // Not a tuple: the resolved node is the single output.
  ret.push_back(item_with_index.first);
  return ret;
}
// Return the node's attribute input (input 0), which holds the primitive.
AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->input(kAnfPrimitiveIndex);
}
  151. PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. auto cnode = node->cast<CNodePtr>();
  154. MS_EXCEPTION_IF_NULL(cnode);
  155. auto attr_input = GetCNodePrimitiveNode(cnode);
  156. MS_EXCEPTION_IF_NULL(attr_input);
  157. auto value_node = attr_input->cast<ValueNodePtr>();
  158. MS_EXCEPTION_IF_NULL(value_node);
  159. auto value = value_node->value();
  160. MS_EXCEPTION_IF_NULL(value);
  161. auto primitive = value->cast<PrimitivePtr>();
  162. return primitive;
  163. }
  164. bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. if (!node->isa<CNode>()) {
  167. return false;
  168. }
  169. auto cnode = node->cast<CNodePtr>();
  170. MS_EXCEPTION_IF_NULL(cnode);
  171. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  172. }
  173. std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  174. MS_EXCEPTION_IF_NULL(node);
  175. if (node->isa<CNode>()) {
  176. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  177. MS_EXCEPTION_IF_NULL(primitive);
  178. return primitive->name();
  179. }
  180. MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
  181. }
// Thin wrapper over AnfNode::DebugString with a null check.
std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->DebugString();
}
  186. void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  187. MS_EXCEPTION_IF_NULL(node);
  188. if (!node->isa<CNode>()) {
  189. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  190. }
  191. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  192. MS_EXCEPTION_IF_NULL(primitive);
  193. primitive->set_attr(key, value);
  194. }
// Copy attribute `key` from one CNode to another, keeping the same key name.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
// Copy attribute `old_key` of `from`'s primitive into attribute `new_key` of
// `to`'s primitive. Both nodes must be CNodes.
// NOTE(review): if `from` has no attr `old_key`, GetAttr presumably returns a
// null ValuePtr which is then stored as-is — confirm intended.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
                                       const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
}
  212. void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  213. MS_EXCEPTION_IF_NULL(from);
  214. MS_EXCEPTION_IF_NULL(to);
  215. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  216. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
  217. << from->DebugString();
  218. }
  219. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  220. MS_EXCEPTION_IF_NULL(from_primitive);
  221. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  222. MS_EXCEPTION_IF_NULL(to_primitive);
  223. (void)to_primitive->SetAttrs(from_primitive->attrs());
  224. }
  225. void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  226. MS_EXCEPTION_IF_NULL(node);
  227. if (!node->isa<CNode>()) {
  228. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  229. }
  230. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  231. MS_EXCEPTION_IF_NULL(primitive);
  232. primitive->EraseAttr(key);
  233. }
// True iff `node` is a CNode whose primitive has attribute `key`.
// Unlike the other attr helpers this only WARNS (and returns false) for
// non-CNodes instead of throwing.
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
    return false;
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  return primitive->HasAttr(key);
}
// Number of REAL inputs of a CNode, i.e. inputs().size() minus input 0
// (the value node storing the primitive/attrs).
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t input_num = cnode->inputs().size();
  // A well-formed CNode always has at least input 0.
  if (input_num == 0) {
    MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  }
  // Exclude inputs[0], which is the value node storing attrs; the rest are real inputs.
  return input_num - 1;
}
  258. size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  259. MS_EXCEPTION_IF_NULL(node);
  260. TypePtr type = node->Type();
  261. MS_EXCEPTION_IF_NULL(type);
  262. if (type->isa<Tuple>()) {
  263. auto tuple_type = type->cast<TuplePtr>();
  264. MS_EXCEPTION_IF_NULL(tuple_type);
  265. return tuple_type->size();
  266. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  267. return 1;
  268. } else if (type->isa<TypeNone>()) {
  269. return 0;
  270. } else {
  271. return 1;
  272. }
  273. }
// Device data format (e.g. NCHW/NC1HWC0) of output `output_idx`, read from the
// node's selected kernel build info.
// NOTE(review): the bound check uses `>` rather than `>=`, so output_idx ==
// GetOutputTensorNum(node) slips through; this pattern is consistent across the
// file, so it may be deliberate — confirm before tightening.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->GetOutputFormat(output_idx);
}
// Device data format expected for input `input_idx`, from the selected kernel
// build info. Same `>` (not `>=`) bound pattern as GetOutputFormat.
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Input index :" << input_idx
                      << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
                      << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->GetInputFormat(input_idx);
}
// Find the (producer node, output index) pair feeding real input `input_idx`
// of `anf_node`, resolving through MakeTuple/TupleGetItem/Depend via VisitKernel.
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 skips input 0 (the primitive); equivalent to input_idx >= GetInputTensorNum.
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  }
  auto node = cnode->input(input_idx + 1);
  MS_EXCEPTION_IF_NULL(node);
  return VisitKernel(node, 0);
}
  314. std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  315. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  316. return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  317. }
// Inferred (framework-level) shape of output `output_idx`, read from the node's
// abstract shape. Scalars (NoShape) yield an empty vector; tuple outputs are
// indexed into; anything else raises.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  // Single-output node: only index 0 is valid for a plain Shape.
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar element inside the tuple.
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                        << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    // Scalar output.
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}
  346. std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  347. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  348. return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
  349. }
// Device-side shape of output `output_idx`: the inferred shape padded to 4D when
// the output format requires it, then transformed to the device layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = GetOutputInferShape(node, output_idx);
  // Scalars have no device shape transformation.
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // If format is default_format or NC1KHKWHWC0, device shape == original shape.
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
// Device-side shape of input `input_idx`: the producer's inferred shape, padded
// to 4D when the input format requires it, then transformed to device layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  // Scalars have no device shape transformation.
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // If format is default_format or NC1KHKWHWC0, device shape == original shape.
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
// Axis reshape hints for input `input_idx` from the selected kernel build info;
// empty when the input uses default padding.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  // Default padding needs no explicit reshape axes.
  if (build_info->IsInputDefaultPadding()) {
    return {};
  }
  return build_info->GetInputReshapeType(input_idx);
}
// Axis reshape hints for output `output_idx` from the selected kernel build info;
// empty when the output uses default padding.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  // Default padding needs no explicit reshape axes.
  if (build_info->IsOutputDefaultPadding()) {
    return {};
  }
  return build_info->GetOutputReshapeType(output_idx);
}
// Inferred (framework-level) element TypeId of output `output_idx`.
// For tensors the element type is returned; for tuples the indexed member is
// inspected; plain Numbers return their own type id. Unknown member types fall
// back to the member's type_id with a warning.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  // Single tensor output: only index 0 is valid.
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      // Unrecognized tuple member type: warn but still return its type id.
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  // Fallback for any other type kind.
  return type_ptr->type_id();
}
  440. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  441. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  442. return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  443. }
// Device-side dtype of output `output_idx`, from the selected kernel build info
// (may differ from the inferred dtype, e.g. fp32 inferred vs fp16 on device).
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->GetOutputDeviceType(output_idx);
}
// Device-side dtype expected for input `input_idx`, from the selected kernel
// build info.
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                      << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->GetInputDeviceType(input_idx);
}
  468. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  469. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  470. return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  471. }
// Get the (non-owning) device address of output `output_idx` of `node`.
// Nop nodes own no memory of their own, so their output address is forwarded
// from their single real input's producer.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node has exactly one real input (primitive + 1 == 2 inputs).
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Owning (shared_ptr) variant of GetOutputAddr; same nop-node forwarding and
// range/null checks.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node forwards the address of its single real input's producer.
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Check whether a device address has been assigned to output `output_idx`.
// (Original comment said "get output device addr" — this is an existence check.)
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->OutputAddrExist(output_idx);
}
  532. const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  533. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  534. return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
  535. }
  536. DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  537. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  538. return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
  539. }
  540. // set output device addr of anf_node
  541. void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  542. MS_EXCEPTION_IF_NULL(node);
  543. auto kernel_info = node->kernel_info();
  544. MS_EXCEPTION_IF_NULL(kernel_info);
  545. if (!kernel_info->SetOutputAddr(addr, output_idx)) {
  546. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  547. }
  548. }
  549. // set workspace device addr of anf_node
  550. void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  551. MS_EXCEPTION_IF_NULL(node);
  552. auto kernel_info = node->kernel_info();
  553. MS_EXCEPTION_IF_NULL(kernel_info);
  554. if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
  555. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  556. }
  557. }
  558. // get workspace device addr of anf_node
  559. DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  560. MS_EXCEPTION_IF_NULL(node);
  561. auto kernel_info = node->kernel_info();
  562. MS_EXCEPTION_IF_NULL(kernel_info);
  563. auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  564. if (addr == nullptr) {
  565. MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
  566. << "] workspace addr is not exist";
  567. }
  568. return addr;
  569. }
// Rebuild `node`'s abstract from parallel lists of output dtypes and shapes:
// one (type, shape) pair becomes a single AbstractTensor, several become an
// AbstractTuple of AbstractTensors. `types` and `shapes` must be the same
// length and non-empty.
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  }
  if (shapes.empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  } else if (shapes.size() == 1) {
    // Single output: abstract is one tensor.
    std::vector<int> shape_int;
    // AbstractTensor takes signed dims, hence the size_t -> int conversion.
    std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
    node->set_abstract(abstract);
  } else {
    // Multiple outputs: abstract is a tuple of tensors.
    std::vector<AbstractBasePtr> abstract_list;
    for (size_t i = 0; i < types.size(); ++i) {
      std::vector<int> shape_int;
      std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
      abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
    }
    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
    node->set_abstract(abstract_tuple);
  }
}
  597. // copy an abstract of a node to another node
  598. void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  599. to_node->set_abstract(from_node->abstract());
  600. }
// Get the kernel build type of `node` (e.g. AKG/TBE/RT/etc.) from the selected
// kernel build info.
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  // select_kernel_build_info() has checked whether return pointer is null.
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->kernel_type();
}
  611. kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  612. MS_EXCEPTION_IF_NULL(node);
  613. auto kernel_info = node->kernel_info();
  614. MS_EXCEPTION_IF_NULL(kernel_info);
  615. auto build_info = kernel_info->select_kernel_build_info();
  616. MS_EXCEPTION_IF_NULL(build_info);
  617. return build_info->processor();
  618. }
  619. kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  620. MS_EXCEPTION_IF_NULL(node);
  621. auto kernel_info = node->kernel_info();
  622. MS_EXCEPTION_IF_NULL(kernel_info);
  623. auto build_info = kernel_info->select_kernel_build_info();
  624. MS_EXCEPTION_IF_NULL(build_info);
  625. return build_info->fusion_type();
  626. }
  627. // set select kernel_build_info
  628. void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  629. MS_EXCEPTION_IF_NULL(node);
  630. auto kernel_info = node->kernel_info();
  631. MS_EXCEPTION_IF_NULL(kernel_info);
  632. return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
  633. }
  634. // get select kernel_build_info
  635. KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  636. MS_EXCEPTION_IF_NULL(node);
  637. auto kernel_info = node->kernel_info();
  638. MS_EXCEPTION_IF_NULL(kernel_info);
  639. return kernel_info->GetMutableSelectKernelBuildInfo();
  640. }
  641. // get kernelMode
  642. KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  643. MS_EXCEPTION_IF_NULL(node);
  644. auto kernel_info = node->kernel_info();
  645. MS_EXCEPTION_IF_NULL(kernel_info);
  646. return kernel_info->MutableKernelMod();
  647. }
  648. // set kernel mod
  649. void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  650. MS_EXCEPTION_IF_NULL(node);
  651. auto kernel_info = node->kernel_info();
  652. MS_EXCEPTION_IF_NULL(kernel_info);
  653. kernel_info->set_kernel_mod(kernel_mod);
  654. }
  655. bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  656. MS_EXCEPTION_IF_NULL(node);
  657. // parameter and value node is not a real kernel too
  658. if (!node->isa<CNode>()) {
  659. return true;
  660. }
  661. auto cnode = node->cast<CNodePtr>();
  662. MS_EXCEPTION_IF_NULL(cnode);
  663. if (cnode->inputs().empty()) {
  664. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  665. }
  666. auto input = cnode->inputs()[0];
  667. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  668. IsPrimitive(input, prim::kPrimTensorSummary) ||
  669. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  670. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  671. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  672. IsPrimitive(input, prim::kPrimReturn);
  673. return !is_virtual_node;
  674. }
  675. bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  676. MS_EXCEPTION_IF_NULL(node);
  677. // parameter and value node is not a real cnode kernel
  678. if (!node->isa<CNode>()) {
  679. return false;
  680. }
  681. // return considered as a real node
  682. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  683. return true;
  684. }
  685. return IsRealKernel(node);
  686. }
  687. bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  688. MS_EXCEPTION_IF_NULL(node);
  689. return node->has_default();
  690. }
  691. void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  692. MS_EXCEPTION_IF_NULL(node);
  693. auto kernel_info = node->kernel_info();
  694. MS_EXCEPTION_IF_NULL(kernel_info);
  695. kernel_info->set_stream_id(stream_id);
  696. }
  697. uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  698. MS_EXCEPTION_IF_NULL(node);
  699. auto kernel_info = node->kernel_info();
  700. MS_EXCEPTION_IF_NULL(kernel_info);
  701. return kernel_info->stream_id();
  702. }
  703. void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  704. MS_EXCEPTION_IF_NULL(node);
  705. auto kernel_info = node->kernel_info();
  706. MS_EXCEPTION_IF_NULL(kernel_info);
  707. kernel_info->set_stream_distinction_label(stream_label);
  708. }
  709. uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  710. MS_EXCEPTION_IF_NULL(node);
  711. auto kernel_info = node->kernel_info();
  712. MS_EXCEPTION_IF_NULL(kernel_info);
  713. return kernel_info->stream_distinction_label();
  714. }
  715. void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  716. MS_EXCEPTION_IF_NULL(node);
  717. auto kernel_info = node->kernel_info();
  718. MS_EXCEPTION_IF_NULL(kernel_info);
  719. kernel_info->set_graph_id(graph_id);
  720. }
  721. uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  722. MS_EXCEPTION_IF_NULL(node);
  723. auto kernel_info = node->kernel_info();
  724. MS_EXCEPTION_IF_NULL(kernel_info);
  725. return kernel_info->graph_id();
  726. }
  727. bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  728. MS_EXCEPTION_IF_NULL(anf);
  729. TypePtr type = anf->Type();
  730. MS_EXCEPTION_IF_NULL(type);
  731. return type->isa<Tuple>();
  732. }
  733. AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  734. MS_EXCEPTION_IF_NULL(node);
  735. auto get_input_index = index + 1;
  736. if (index + 1 > node->inputs().size()) {
  737. MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
  738. << node->inputs().size();
  739. }
  740. // input 0 is primitive node
  741. return node->input(get_input_index);
  742. }
  743. bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  744. MS_EXCEPTION_IF_NULL(node);
  745. if (node->isa<ValueNode>()) {
  746. return false;
  747. }
  748. auto kernel_info = node->kernel_info();
  749. MS_EXCEPTION_IF_NULL(kernel_info);
  750. return kernel_info->is_feature_map();
  751. }
  752. bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  753. if (!node->isa<CNode>()) {
  754. MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  755. }
  756. auto cnode = node->cast<CNodePtr>();
  757. MS_EXCEPTION_IF_NULL(cnode);
  758. auto input_node = cnode->input(input_index + 1);
  759. return IsFeatureMapOutput(input_node);
  760. }
  761. size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  762. MS_EXCEPTION_IF_NULL(anf_node);
  763. static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
  764. {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
  765. {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
  766. {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
  767. {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  768. {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  769. {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  770. {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  771. {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}};
  772. size_t ret = cur_index;
  773. auto node_name = AnfAlgo::GetCNodeName(anf_node);
  774. if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
  775. auto find = spec_node_list.find(node_name);
  776. if (find != spec_node_list.end()) {
  777. ret = find->second[cur_index];
  778. MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
  779. }
  780. }
  781. return ret;
  782. }
  783. void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  784. MS_EXCEPTION_IF_NULL(node);
  785. MS_EXCEPTION_IF_NULL(input_node);
  786. node->set_input(index + 1, input_node);
  787. }
  788. bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  789. MS_EXCEPTION_IF_NULL(node);
  790. auto kernel_name = AnfAlgo::GetCNodeName(node);
  791. auto kernel_type = AnfAlgo::GetKernelType(node);
  792. if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) {
  793. return true;
  794. }
  795. return false;
  796. }
  797. bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
  798. MS_EXCEPTION_IF_NULL(node);
  799. if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
  800. return true;
  801. }
  802. return false;
  803. }
  804. bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  805. auto kernel_name = AnfAlgo::GetCNodeName(node);
  806. return kernel_name == kGetNextOpName;
  807. }
  808. } // namespace session
  809. } // namespace mindspore