You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.cc 36 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/anf_runtime_algorithm.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "operator/ops.h"
  24. #include "utils/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "device/device_address.h"
  27. #include "pre_activate/common/helper.h"
  28. #include "kernel/kernel.h"
  29. #include "kernel/kernel_build_info.h"
  30. #include "common/utils.h"
  31. #include "common/trans.h"
  32. namespace mindspore {
  33. namespace session {
  34. using abstract::AbstractTensor;
  35. using abstract::AbstractTuple;
  36. using device::KernelInfo;
  37. using device::ascend::AscendDeviceAddress;
  38. using kernel::KernelBuildInfoPtr;
  39. using kernel::KernelMod;
  40. using kernel::KernelModPtr;
  41. namespace {
  42. std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  43. MS_EXCEPTION_IF_NULL(shape);
  44. std::vector<size_t> shape_size_t;
  45. std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  46. return shape_size_t;
  47. }
  48. } // namespace
// Follow an output edge backwards through "glue" nodes (MakeTuple / TupleGetItem /
// Depend / ControlDepend) until the real node producing output `index` is found.
// Returns that node paired with the output index on it.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    // Values and parameters have exactly one output, so the index collapses to 0.
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // `index` selects the tuple element; +1 skips the primitive at input(0).
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // The second real input of TupleGetItem is a constant holding the output index.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      // Depend / ControlDepend forward their first real input unchanged.
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      // A real kernel node: stop here.
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Like VisitKernel, but (a) stops early at any node whose primitive is listed in
// `return_types`, and (b) optionally skips through nop nodes when `visit_nop_node`
// is true. Note: unlike VisitKernel, MakeTuple is NOT traversed here.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Early return: the caller wants nodes of these primitive types kept as-is.
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // The second real input of TupleGetItem is a constant holding the output index.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A nop node must have exactly one real input (primitive + 1 input = 2).
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Collect all real output nodes reachable from `node`, recursively flattening
// MakeTuple nodes. Elements matching `return_types` are kept without traversal.
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // if visited make_tuple should return back
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    // Recurse into each tuple element (input 0 is the MakeTuple primitive itself).
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  // Not a tuple: the single resolved node is the only output.
  ret.push_back(item_with_index.first);
  return ret;
}
  147. AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  148. MS_EXCEPTION_IF_NULL(node);
  149. return node->input(kAnfPrimitiveIndex);
  150. }
  151. PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. auto cnode = node->cast<CNodePtr>();
  154. MS_EXCEPTION_IF_NULL(cnode);
  155. auto attr_input = GetCNodePrimitiveNode(cnode);
  156. MS_EXCEPTION_IF_NULL(attr_input);
  157. auto value_node = attr_input->cast<ValueNodePtr>();
  158. MS_EXCEPTION_IF_NULL(value_node);
  159. auto value = value_node->value();
  160. MS_EXCEPTION_IF_NULL(value);
  161. auto primitive = value->cast<PrimitivePtr>();
  162. return primitive;
  163. }
  164. bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. if (!node->isa<CNode>()) {
  167. return false;
  168. }
  169. auto cnode = node->cast<CNodePtr>();
  170. MS_EXCEPTION_IF_NULL(cnode);
  171. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  172. }
// Return the primitive name of a cnode (e.g. "Conv2D"); throws for any other
// kind of anf node.
std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto primitive = AnfAlgo::GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(primitive);
    return primitive->name();
  }
  MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
}
  182. std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  183. MS_EXCEPTION_IF_NULL(node);
  184. return node->DebugString();
  185. }
  186. void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  187. MS_EXCEPTION_IF_NULL(node);
  188. if (!node->isa<CNode>()) {
  189. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  190. }
  191. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  192. MS_EXCEPTION_IF_NULL(primitive);
  193. primitive->set_attr(key, value);
  194. }
// Copy attribute `key` from one cnode's primitive to another, keeping the same key.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
// Copy attribute `old_key` of `from`'s primitive onto `to`'s primitive as `new_key`.
// NOTE(review): if `from` lacks `old_key`, GetAttr returns a null value which is
// stored as-is on `to` — confirm callers always copy existing attrs.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
                                       const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
}
  212. void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  213. MS_EXCEPTION_IF_NULL(from);
  214. MS_EXCEPTION_IF_NULL(to);
  215. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  216. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
  217. << from->DebugString();
  218. }
  219. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  220. MS_EXCEPTION_IF_NULL(from_primitive);
  221. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  222. MS_EXCEPTION_IF_NULL(to_primitive);
  223. (void)to_primitive->SetAttrs(from_primitive->attrs());
  224. }
// Remove attribute `key` from the node's primitive; only cnodes carry attrs.
// NOTE(review): `node` is a const-by-value smart pointer (extra refcount bump);
// the signature must match the header declaration, so it is left as-is.
void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->EraseAttr(key);
}
// True iff the node is a cnode whose primitive carries attribute `key`.
// Non-cnodes only warn (not throw) and report false, unlike the setters above.
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
    return false;
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  return primitive->HasAttr(key);
}
// Number of real (data) inputs of a cnode: total inputs minus the primitive slot.
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t input_num = cnode->inputs().size();
  if (input_num == 0) {
    MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  }
  // exclude inputs[0], which is the value_node storing the primitive/attrs;
  // the remaining inputs are the real data inputs
  return input_num - 1;
}
  258. size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  259. MS_EXCEPTION_IF_NULL(node);
  260. TypePtr type = node->Type();
  261. MS_EXCEPTION_IF_NULL(type);
  262. if (type->isa<Tuple>()) {
  263. auto tuple_type = type->cast<TuplePtr>();
  264. MS_EXCEPTION_IF_NULL(tuple_type);
  265. return tuple_type->size();
  266. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  267. return 1;
  268. } else if (type->isa<TypeNone>()) {
  269. return 0;
  270. } else {
  271. return 1;
  272. }
  273. }
  274. std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  275. MS_EXCEPTION_IF_NULL(node);
  276. if (output_idx > GetOutputTensorNum(node)) {
  277. MS_LOG(EXCEPTION) << "Output index:" << output_idx
  278. << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
  279. << node->DebugString() << "]";
  280. }
  281. auto kernel_info = node->kernel_info();
  282. MS_EXCEPTION_IF_NULL(kernel_info);
  283. auto build_info = kernel_info->select_kernel_build_info();
  284. MS_EXCEPTION_IF_NULL(build_info);
  285. auto format = build_info->GetOutputFormat(output_idx);
  286. if (format == kernel::KernelBuildInfo::kInvalidFormat) {
  287. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  288. << " has a invalid output format";
  289. }
  290. return format;
  291. }
  292. std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  293. MS_EXCEPTION_IF_NULL(node);
  294. if (input_idx > GetInputTensorNum(node)) {
  295. MS_LOG(EXCEPTION) << "Input index :" << input_idx
  296. << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
  297. << node->DebugString() << "]";
  298. }
  299. auto kernel_info = node->kernel_info();
  300. MS_EXCEPTION_IF_NULL(kernel_info);
  301. auto build_info = kernel_info->select_kernel_build_info();
  302. MS_EXCEPTION_IF_NULL(build_info);
  303. auto format = build_info->GetInputFormat(input_idx);
  304. if (format == kernel::KernelBuildInfo::kInvalidFormat) {
  305. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  306. << " has a invalid input format";
  307. }
  308. return format;
  309. }
// Resolve real input `input_idx` of a cnode to the kernel (node, output index)
// that actually produces it, skipping glue nodes via VisitKernel.
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 converts the real-input index into a cnode input index (slot 0 is the primitive).
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  }
  auto node = cnode->input(input_idx + 1);
  MS_EXCEPTION_IF_NULL(node);
  return VisitKernel(node, 0);
}
  324. std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  325. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  326. return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  327. }
  328. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  329. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  330. return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
  331. }
  332. std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  333. MS_EXCEPTION_IF_NULL(node);
  334. abstract::BaseShapePtr base_shape = node->Shape();
  335. MS_EXCEPTION_IF_NULL(base_shape);
  336. if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
  337. return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  338. } else if (base_shape->isa<abstract::TupleShape>()) {
  339. auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
  340. MS_EXCEPTION_IF_NULL(tuple_shape);
  341. if (output_idx >= tuple_shape->size()) {
  342. MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
  343. << ".";
  344. }
  345. auto b_shp = (*tuple_shape)[output_idx];
  346. if (b_shp->isa<abstract::Shape>()) {
  347. return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
  348. } else if (b_shp->isa<abstract::NoShape>()) {
  349. return std::vector<size_t>();
  350. } else {
  351. MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
  352. << base_shape->ToString();
  353. }
  354. } else if (base_shape->isa<abstract::NoShape>()) {
  355. return std::vector<size_t>();
  356. }
  357. MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
  358. << base_shape->ToString();
  359. }
  360. std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  361. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  362. return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
  363. }
// Device-side shape of output `output_idx`: the inferred shape, padded to 4-D
// when the chosen format requires it, then converted to the device layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = GetOutputInferShape(node, output_idx);
  // Scalar / shapeless outputs pass through unchanged.
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0, device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
// Device-side shape for input `input_idx`: the producer's inferred shape, padded
// to 4-D when this node's input format requires it, then converted to device layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  // Scalar / shapeless inputs pass through unchanged.
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0, device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
  388. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  389. MS_EXCEPTION_IF_NULL(node);
  390. if (input_idx > GetInputTensorNum(node)) {
  391. MS_LOG(EXCEPTION) << "The index:" << input_idx
  392. << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
  393. << node->DebugString() << "]";
  394. }
  395. auto kernel_info = node->kernel_info();
  396. MS_EXCEPTION_IF_NULL(kernel_info);
  397. auto build_info = kernel_info->select_kernel_build_info();
  398. MS_EXCEPTION_IF_NULL(build_info);
  399. if (build_info->IsInputDefaultPadding()) {
  400. return {};
  401. }
  402. return build_info->GetInputReshapeType(input_idx);
  403. }
  404. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  405. MS_EXCEPTION_IF_NULL(node);
  406. if (output_idx > GetOutputTensorNum(node)) {
  407. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  408. << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  409. }
  410. auto kernel_info = node->kernel_info();
  411. MS_EXCEPTION_IF_NULL(kernel_info);
  412. auto build_info = kernel_info->select_kernel_build_info();
  413. MS_EXCEPTION_IF_NULL(build_info);
  414. if (build_info->IsOutputDefaultPadding()) {
  415. return {};
  416. }
  417. return build_info->GetOutputReshapeType(output_idx);
  418. }
// Inferred (framework-side) element dtype of output `output_idx`.
// TensorType yields its element type; tuples are indexed per element; anything
// else falls back to the type's own id.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    // Single tensor output: only index 0 is meaningful.
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      // Unsupported element type: warn but still return its id as a best effort.
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  // Fallback for every other type (including TensorType with output_idx != 0).
  return type_ptr->type_id();
}
  454. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  455. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  456. return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  457. }
  458. TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  459. MS_EXCEPTION_IF_NULL(node);
  460. if (output_idx > GetOutputTensorNum(node)) {
  461. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  462. << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  463. }
  464. auto kernel_info = node->kernel_info();
  465. MS_EXCEPTION_IF_NULL(kernel_info);
  466. auto build_info = kernel_info->select_kernel_build_info();
  467. MS_EXCEPTION_IF_NULL(build_info);
  468. auto dtype = build_info->GetOutputDeviceType(output_idx);
  469. if (dtype == TypeId::kNumberTypeEnd) {
  470. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  471. << " has a invalid dtype";
  472. }
  473. return dtype;
  474. }
  475. TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  476. MS_EXCEPTION_IF_NULL(node);
  477. if (input_idx > GetInputTensorNum(node)) {
  478. MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
  479. << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  480. }
  481. auto kernel_info = node->kernel_info();
  482. MS_EXCEPTION_IF_NULL(kernel_info);
  483. auto build_info = kernel_info->select_kernel_build_info();
  484. MS_EXCEPTION_IF_NULL(build_info);
  485. auto dtype = build_info->GetInputDeviceType(input_idx);
  486. if (dtype == TypeId::kNumberTypeEnd) {
  487. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  488. << " has a invalid dtype";
  489. }
  490. return dtype;
  491. }
  492. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  493. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  494. return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  495. }
// get output device addr of anf_node
// For nop nodes the address is forwarded from the single real input's producer.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node must have exactly one real input (primitive + 1 input = 2).
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Mutable (shared-ptr) variant of GetOutputAddr; nop nodes forward the address
// of their single real input's producer.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node must have exactly one real input (primitive + 1 input = 2).
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
  537. // get output device addr of anf_node
  538. bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  539. MS_EXCEPTION_IF_NULL(node);
  540. if (output_idx > GetOutputTensorNum(node)) {
  541. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  542. << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  543. }
  544. auto kernel_info = node->kernel_info();
  545. MS_EXCEPTION_IF_NULL(kernel_info);
  546. return kernel_info->OutputAddrExist(output_idx);
  547. }
  548. const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  549. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  550. return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
  551. }
  552. DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  553. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  554. return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
  555. }
  556. // set output device addr of anf_node
  557. void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  558. MS_EXCEPTION_IF_NULL(node);
  559. auto kernel_info = node->kernel_info();
  560. MS_EXCEPTION_IF_NULL(kernel_info);
  561. if (!kernel_info->SetOutputAddr(addr, output_idx)) {
  562. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  563. }
  564. }
  565. // set workspace device addr of anf_node
  566. void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  567. MS_EXCEPTION_IF_NULL(node);
  568. auto kernel_info = node->kernel_info();
  569. MS_EXCEPTION_IF_NULL(kernel_info);
  570. if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
  571. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  572. }
  573. }
  574. // get workspace device addr of anf_node
  575. DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  576. MS_EXCEPTION_IF_NULL(node);
  577. auto kernel_info = node->kernel_info();
  578. MS_EXCEPTION_IF_NULL(kernel_info);
  579. auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  580. if (addr == nullptr) {
  581. MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
  582. << "] workspace addr is not exist";
  583. }
  584. return addr;
  585. }
  586. // set infer shapes and types of anf node
  587. void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  588. const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  589. MS_EXCEPTION_IF_NULL(node);
  590. if (types.size() != shapes.size()) {
  591. MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  592. }
  593. if (shapes.empty()) {
  594. MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  595. } else if (shapes.size() == 1) {
  596. // single output handle
  597. std::vector<int> shape_int;
  598. std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
  599. auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
  600. node->set_abstract(abstract);
  601. } else {
  602. // multiple output handle
  603. std::vector<AbstractBasePtr> abstract_list;
  604. for (size_t i = 0; i < types.size(); ++i) {
  605. std::vector<int> shape_int;
  606. std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
  607. abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
  608. }
  609. auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
  610. node->set_abstract(abstract_tuple);
  611. }
  612. }
  613. // copy an abstract of a node to another node
  614. void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  615. to_node->set_abstract(from_node->abstract());
  616. }
  617. // get KernelBuildType of node, such as ATT,RT,FWK and so on
  618. KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  619. MS_EXCEPTION_IF_NULL(node);
  620. auto kernel_info = node->kernel_info();
  621. MS_EXCEPTION_IF_NULL(kernel_info);
  622. // select_kernel_build_info() has checked whether return pointer is null
  623. auto build_info = kernel_info->select_kernel_build_info();
  624. MS_EXCEPTION_IF_NULL(build_info);
  625. return build_info->kernel_type();
  626. }
  627. kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  628. MS_EXCEPTION_IF_NULL(node);
  629. auto kernel_info = node->kernel_info();
  630. MS_EXCEPTION_IF_NULL(kernel_info);
  631. auto build_info = kernel_info->select_kernel_build_info();
  632. MS_EXCEPTION_IF_NULL(build_info);
  633. return build_info->processor();
  634. }
  635. kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  636. MS_EXCEPTION_IF_NULL(node);
  637. auto kernel_info = node->kernel_info();
  638. MS_EXCEPTION_IF_NULL(kernel_info);
  639. auto build_info = kernel_info->select_kernel_build_info();
  640. MS_EXCEPTION_IF_NULL(build_info);
  641. return build_info->fusion_type();
  642. }
  643. // set select kernel_build_info
  644. void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  645. MS_EXCEPTION_IF_NULL(node);
  646. auto kernel_info = node->kernel_info();
  647. MS_EXCEPTION_IF_NULL(kernel_info);
  648. return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
  649. }
  650. // get select kernel_build_info
  651. KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  652. MS_EXCEPTION_IF_NULL(node);
  653. auto kernel_info = node->kernel_info();
  654. MS_EXCEPTION_IF_NULL(kernel_info);
  655. return kernel_info->GetMutableSelectKernelBuildInfo();
  656. }
  657. // get kernelMode
  658. KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  659. MS_EXCEPTION_IF_NULL(node);
  660. auto kernel_info = node->kernel_info();
  661. MS_EXCEPTION_IF_NULL(kernel_info);
  662. return kernel_info->MutableKernelMod();
  663. }
  664. // set kernel mod
  665. void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  666. MS_EXCEPTION_IF_NULL(node);
  667. auto kernel_info = node->kernel_info();
  668. MS_EXCEPTION_IF_NULL(kernel_info);
  669. kernel_info->set_kernel_mod(kernel_mod);
  670. }
  671. bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  672. MS_EXCEPTION_IF_NULL(node);
  673. // parameter and value node is not a real kernel too
  674. if (!node->isa<CNode>()) {
  675. return true;
  676. }
  677. auto cnode = node->cast<CNodePtr>();
  678. MS_EXCEPTION_IF_NULL(cnode);
  679. if (cnode->inputs().empty()) {
  680. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  681. }
  682. auto input = cnode->inputs()[0];
  683. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  684. IsPrimitive(input, prim::kPrimTensorSummary) ||
  685. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  686. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  687. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  688. IsPrimitive(input, prim::kPrimReturn);
  689. return !is_virtual_node;
  690. }
  691. bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  692. MS_EXCEPTION_IF_NULL(node);
  693. // parameter and value node is not a real cnode kernel
  694. if (!node->isa<CNode>()) {
  695. return false;
  696. }
  697. // return considered as a real node
  698. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  699. return true;
  700. }
  701. return IsRealKernel(node);
  702. }
  703. bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  704. MS_EXCEPTION_IF_NULL(node);
  705. return node->has_default();
  706. }
  707. void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  708. MS_EXCEPTION_IF_NULL(node);
  709. auto kernel_info = node->kernel_info();
  710. MS_EXCEPTION_IF_NULL(kernel_info);
  711. kernel_info->set_stream_id(stream_id);
  712. }
  713. uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  714. MS_EXCEPTION_IF_NULL(node);
  715. auto kernel_info = node->kernel_info();
  716. MS_EXCEPTION_IF_NULL(kernel_info);
  717. return kernel_info->stream_id();
  718. }
  719. void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  720. MS_EXCEPTION_IF_NULL(node);
  721. auto kernel_info = node->kernel_info();
  722. MS_EXCEPTION_IF_NULL(kernel_info);
  723. kernel_info->set_stream_distinction_label(stream_label);
  724. }
  725. uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  726. MS_EXCEPTION_IF_NULL(node);
  727. auto kernel_info = node->kernel_info();
  728. MS_EXCEPTION_IF_NULL(kernel_info);
  729. return kernel_info->stream_distinction_label();
  730. }
  731. void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  732. MS_EXCEPTION_IF_NULL(node);
  733. auto kernel_info = node->kernel_info();
  734. MS_EXCEPTION_IF_NULL(kernel_info);
  735. kernel_info->set_graph_id(graph_id);
  736. }
  737. uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  738. MS_EXCEPTION_IF_NULL(node);
  739. auto kernel_info = node->kernel_info();
  740. MS_EXCEPTION_IF_NULL(kernel_info);
  741. return kernel_info->graph_id();
  742. }
  743. bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  744. MS_EXCEPTION_IF_NULL(anf);
  745. TypePtr type = anf->Type();
  746. MS_EXCEPTION_IF_NULL(type);
  747. return type->isa<Tuple>();
  748. }
  749. AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  750. MS_EXCEPTION_IF_NULL(node);
  751. auto get_input_index = index + 1;
  752. if (index + 1 > node->inputs().size()) {
  753. MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
  754. << node->inputs().size();
  755. }
  756. // input 0 is primitive node
  757. return node->input(get_input_index);
  758. }
  759. bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  760. MS_EXCEPTION_IF_NULL(node);
  761. if (node->isa<ValueNode>()) {
  762. return false;
  763. }
  764. auto kernel_info = node->kernel_info();
  765. MS_EXCEPTION_IF_NULL(kernel_info);
  766. return kernel_info->is_feature_map();
  767. }
  768. bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  769. if (!node->isa<CNode>()) {
  770. MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  771. }
  772. auto cnode = node->cast<CNodePtr>();
  773. MS_EXCEPTION_IF_NULL(cnode);
  774. auto input_node = cnode->input(input_index + 1);
  775. return IsFeatureMapOutput(input_node);
  776. }
  777. size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  778. MS_EXCEPTION_IF_NULL(anf_node);
  779. static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
  780. {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
  781. {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
  782. {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  783. {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
  784. {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
  785. {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  786. {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  787. {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  788. {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  789. {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}};
  790. size_t ret = cur_index;
  791. auto node_name = AnfAlgo::GetCNodeName(anf_node);
  792. if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
  793. auto find = spec_node_list.find(node_name);
  794. if (find != spec_node_list.end()) {
  795. ret = find->second[cur_index];
  796. MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
  797. }
  798. }
  799. return ret;
  800. }
  801. void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  802. MS_EXCEPTION_IF_NULL(node);
  803. MS_EXCEPTION_IF_NULL(input_node);
  804. node->set_input(index + 1, input_node);
  805. }
  806. bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  807. MS_EXCEPTION_IF_NULL(node);
  808. if (!node->isa<CNode>()) {
  809. return false;
  810. }
  811. auto kernel_name = AnfAlgo::GetCNodeName(node);
  812. if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
  813. kernel_name == kReduceScatterOpName) {
  814. return true;
  815. }
  816. return false;
  817. }
  818. bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  819. auto kernel_name = AnfAlgo::GetCNodeName(node);
  820. return kernel_name == kGetNextOpName;
  821. }
  822. } // namespace session
  823. } // namespace mindspore