You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.cc 40 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/anf_runtime_algorithm.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "operator/ops.h"
  24. #include "utils/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "device/device_address.h"
  27. #include "pre_activate/common/helper.h"
  28. #include "kernel/kernel.h"
  29. #include "kernel/kernel_build_info.h"
  30. #include "common/utils.h"
  31. #include "common/trans.h"
  32. namespace mindspore {
  33. namespace session {
  34. using abstract::AbstractTensor;
  35. using abstract::AbstractTuple;
  36. using device::KernelInfo;
  37. using device::ascend::AscendDeviceAddress;
  38. using kernel::KernelBuildInfoPtr;
  39. using kernel::KernelMod;
  40. using kernel::KernelModPtr;
  41. namespace {
  42. std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  43. MS_EXCEPTION_IF_NULL(shape);
  44. std::vector<size_t> shape_size_t;
  45. std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  46. return shape_size_t;
  47. }
  48. } // namespace
// Resolve (anf_node, index) to the real kernel node that produces the value.
// ValueNode/Parameter are terminal and returned with output index 0; for a
// CNode, wrapper primitives (MakeTuple / TupleGetItem / Depend /
// ControlDepend) are skipped by recursing into their real input.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // 'index' selects the tuple element; +1 skips the primitive input.
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Inputs are [prim, real_input, index_value_node]; read the constant
      // item index and continue into the real input at that output slot.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      // Depend/ControlDepend forward the value of their first real input.
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      // A real kernel: keep the requested output index.
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Like VisitKernel, but (a) stops early and returns any node whose primitive
// matches one of 'return_types', and (b) optionally looks through nop nodes
// when 'visit_nop_node' is true. MakeTuple is NOT skipped here (callers add it
// to 'return_types' when they want it back, see GetAllOutput).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Early-return for any caller-requested primitive type.
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Read the constant item index and recurse into the real input.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A nop node must be exactly [prim, single_input]; look through it.
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Collect all real output nodes of 'node', flattening nested MakeTuple
// structures recursively; nodes matching 'return_types' are returned as-is.
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // if visited make_tuple should return back
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    // Recurse into every tuple element (input 0 is the MakeTuple primitive).
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  ret.push_back(item_with_index.first);
  return ret;
}
  147. AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  148. MS_EXCEPTION_IF_NULL(node);
  149. return node->input(kAnfPrimitiveIndex);
  150. }
  151. PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. auto cnode = node->cast<CNodePtr>();
  154. MS_EXCEPTION_IF_NULL(cnode);
  155. auto attr_input = GetCNodePrimitiveNode(cnode);
  156. MS_EXCEPTION_IF_NULL(attr_input);
  157. auto value_node = attr_input->cast<ValueNodePtr>();
  158. MS_EXCEPTION_IF_NULL(value_node);
  159. auto value = value_node->value();
  160. MS_EXCEPTION_IF_NULL(value);
  161. auto primitive = value->cast<PrimitivePtr>();
  162. return primitive;
  163. }
  164. bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. if (!node->isa<CNode>()) {
  167. return false;
  168. }
  169. auto cnode = node->cast<CNodePtr>();
  170. MS_EXCEPTION_IF_NULL(cnode);
  171. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  172. }
  173. std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  174. MS_EXCEPTION_IF_NULL(node);
  175. if (node->isa<CNode>()) {
  176. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  177. MS_EXCEPTION_IF_NULL(primitive);
  178. return primitive->name();
  179. }
  180. MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
  181. }
  182. std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  183. MS_EXCEPTION_IF_NULL(node);
  184. return node->DebugString();
  185. }
  186. void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  187. MS_EXCEPTION_IF_NULL(node);
  188. if (!node->isa<CNode>()) {
  189. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  190. }
  191. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  192. MS_EXCEPTION_IF_NULL(primitive);
  193. primitive->set_attr(key, value);
  194. }
  195. void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  196. CopyNodeAttr(key, key, from, to);
  197. }
  198. void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  199. const AnfNodePtr &to) {
  200. MS_EXCEPTION_IF_NULL(from);
  201. MS_EXCEPTION_IF_NULL(to);
  202. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  203. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
  204. << to->DebugString();
  205. }
  206. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  207. MS_EXCEPTION_IF_NULL(from_primitive);
  208. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  209. MS_EXCEPTION_IF_NULL(to_primitive);
  210. to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
  211. }
  212. void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  213. MS_EXCEPTION_IF_NULL(from);
  214. MS_EXCEPTION_IF_NULL(to);
  215. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  216. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
  217. << from->DebugString();
  218. }
  219. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  220. MS_EXCEPTION_IF_NULL(from_primitive);
  221. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  222. MS_EXCEPTION_IF_NULL(to_primitive);
  223. (void)to_primitive->SetAttrs(from_primitive->attrs());
  224. }
  225. void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  226. MS_EXCEPTION_IF_NULL(node);
  227. if (!node->isa<CNode>()) {
  228. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  229. }
  230. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  231. MS_EXCEPTION_IF_NULL(primitive);
  232. primitive->EraseAttr(key);
  233. }
  234. bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
  235. MS_EXCEPTION_IF_NULL(node);
  236. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  237. MS_EXCEPTION_IF_NULL(primitive);
  238. return primitive->HasAttr(key);
  239. }
  240. size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  241. MS_EXCEPTION_IF_NULL(node);
  242. if (!node->isa<CNode>()) {
  243. MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  244. }
  245. auto cnode = node->cast<CNodePtr>();
  246. MS_EXCEPTION_IF_NULL(cnode);
  247. size_t input_num = cnode->inputs().size();
  248. if (input_num == 0) {
  249. MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  250. }
  251. // exclude intputs[0],which is value_node storing attr,inputs left are real input
  252. return input_num - 1;
  253. }
  254. size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  255. MS_EXCEPTION_IF_NULL(node);
  256. TypePtr type = node->Type();
  257. if (type == nullptr) {
  258. return 0;
  259. }
  260. if (type->isa<Tuple>()) {
  261. auto tuple_type = type->cast<TuplePtr>();
  262. MS_EXCEPTION_IF_NULL(tuple_type);
  263. return tuple_type->size();
  264. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  265. return 1;
  266. } else if (type->isa<TypeNone>()) {
  267. return 0;
  268. } else {
  269. return 1;
  270. }
  271. }
// Device format of output 'output_idx' of 'node'. For virtual (non-real)
// kernels the format comes from the producing predecessor.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): bound uses '>' not '>=', so output_idx == output count is
  // accepted -- looks off by one; confirm intent before tightening.
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  if (!AnfAlgo::IsRealKernel(node)) {
    // Virtual nodes (TupleGetItem etc.) have no kernel info of their own.
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOutputFormat(output_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format";
  }
  return format;
}
  293. std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  294. MS_EXCEPTION_IF_NULL(node);
  295. if (input_idx > GetInputTensorNum(node)) {
  296. MS_LOG(EXCEPTION) << "Input index :" << input_idx
  297. << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
  298. << node->DebugString() << "]";
  299. }
  300. if (!IsRealKernel(node)) {
  301. GetPrevNodeOutputFormat(node, input_idx);
  302. }
  303. auto kernel_info = node->kernel_info();
  304. MS_EXCEPTION_IF_NULL(kernel_info);
  305. auto build_info = kernel_info->select_kernel_build_info();
  306. MS_EXCEPTION_IF_NULL(build_info);
  307. auto format = build_info->GetInputFormat(input_idx);
  308. if (format == kernel::KernelBuildInfo::kInvalidFormat) {
  309. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  310. << " has a invalid input format";
  311. }
  312. return format;
  313. }
  314. KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  315. MS_EXCEPTION_IF_NULL(anf_node);
  316. if (!anf_node->isa<CNode>()) {
  317. MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  318. }
  319. auto cnode = anf_node->cast<CNodePtr>();
  320. MS_EXCEPTION_IF_NULL(cnode);
  321. if (input_idx + 1 >= cnode->inputs().size()) {
  322. MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  323. }
  324. auto node = cnode->input(input_idx + 1);
  325. MS_EXCEPTION_IF_NULL(node);
  326. return VisitKernel(node, 0);
  327. }
  328. std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  329. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  330. return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  331. }
  332. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  333. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  334. return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
  335. }
// Inferred (framework) shape of output 'output_idx' of 'node'. Scalars /
// NoShape yield an empty vector; tuple outputs are indexed by 'output_idx'.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    // Single-output node: only index 0 is valid for a plain Shape.
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar element inside the tuple.
      return std::vector<size_t>();
    } else {
      // NOTE(review): this message reports base_shape, not the offending
      // element b_shp -- presumably intentional context, but verify.
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}
  364. std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  365. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  366. return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
  367. }
// Device-side shape of output 'output_idx': the inferred shape padded to 4D
// when the device format requires it, then translated to device layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = GetOutputInferShape(node, output_idx);
  if (infer_shape.empty()) {
    // Scalar: nothing to translate.
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
// Device-side shape of input 'input_idx': the predecessor's inferred shape
// padded to 4D when the input format requires it, then translated to device
// layout.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  if (infer_shape.empty()) {
    // Scalar input: nothing to translate.
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
// Reshape type (padding axes) for input 'input_idx'. Empty when the selected
// kernel uses default padding; forwarded from the predecessor for virtual
// nodes.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): bound uses '>' not '>=' -- same pattern as GetOutputFormat.
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsInputDefaultPadding()) {
    return {};
  }
  return build_info->GetInputReshapeType(input_idx);
}
// Reshape type (padding axes) for output 'output_idx'. Empty when the
// selected kernel uses default padding; forwarded from the predecessor for
// virtual nodes.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // NOTE(review): forwards output_idx as an *input* index -- only safe for
    // single-in/out virtual nodes; confirm.
    return GetPrevNodeOutputReshapeType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsOutputDefaultPadding()) {
    return {};
  }
  return build_info->GetOutputReshapeType(output_idx);
}
// Inferred (framework) data type of output 'output_idx'. For tensor outputs
// the element type is returned; tuple outputs are indexed first.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    // Single tensor output: unwrap to the element type.
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      // Unknown element kind: warn but still return its raw type id.
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  // Fallback: any other type reports its own id.
  return type_ptr->type_id();
}
  464. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  465. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  466. return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  467. }
// Device data type selected for output 'output_idx'; forwarded from the
// predecessor for virtual nodes.
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // NOTE(review): forwards output_idx as an *input* index -- only safe for
    // single-in/out virtual nodes; confirm.
    return GetPrevNodeOutputDeviceDataType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetOutputDeviceType(output_idx);
  // kNumberTypeEnd marks "no type selected" in KernelBuildInfo.
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype";
  }
  return dtype;
}
  488. TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  489. MS_EXCEPTION_IF_NULL(node);
  490. if (input_idx > GetInputTensorNum(node)) {
  491. MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
  492. << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  493. }
  494. if (!IsRealKernel(node)) {
  495. return GetPrevNodeOutputDeviceDataType(node, 0);
  496. }
  497. auto kernel_info = node->kernel_info();
  498. MS_EXCEPTION_IF_NULL(kernel_info);
  499. auto build_info = kernel_info->select_kernel_build_info();
  500. MS_EXCEPTION_IF_NULL(build_info);
  501. auto dtype = build_info->GetInputDeviceType(input_idx);
  502. if (dtype == TypeId::kNumberTypeEnd) {
  503. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  504. << " has a invalid dtype";
  505. }
  506. return dtype;
  507. }
  508. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  509. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  510. return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  511. }
// get output device addr of anf_node
// Returns the (non-owning) device address of output 'output_idx'. When
// 'visit_nop_node' is set, nop nodes are looked through to their single input.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                        bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node) && visit_nop_node) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node must be exactly [prim, single_input].
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Mutable (shared-ownership) variant of GetOutputAddr; same nop-node
// look-through behavior.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                           bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node) && visit_nop_node) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A nop node must be exactly [prim, single_input].
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
  555. // get output device addr of anf_node
  556. bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  557. MS_EXCEPTION_IF_NULL(node);
  558. if (output_idx > GetOutputTensorNum(node)) {
  559. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  560. << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  561. }
  562. auto kernel_info = node->kernel_info();
  563. MS_EXCEPTION_IF_NULL(kernel_info);
  564. return kernel_info->OutputAddrExist(output_idx);
  565. }
  566. const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  567. bool visit_nop_node) {
  568. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  569. return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
  570. }
  571. DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  572. bool visit_nop_node) {
  573. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  574. return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
  575. }
  576. // set output device addr of anf_node
  577. void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  578. MS_EXCEPTION_IF_NULL(node);
  579. auto kernel_info = node->kernel_info();
  580. MS_EXCEPTION_IF_NULL(kernel_info);
  581. if (!kernel_info->SetOutputAddr(addr, output_idx)) {
  582. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  583. }
  584. }
  585. // set workspace device addr of anf_node
  586. void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  587. MS_EXCEPTION_IF_NULL(node);
  588. auto kernel_info = node->kernel_info();
  589. MS_EXCEPTION_IF_NULL(kernel_info);
  590. if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
  591. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  592. }
  593. }
  594. // get workspace device addr of anf_node
  595. DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  596. MS_EXCEPTION_IF_NULL(node);
  597. auto kernel_info = node->kernel_info();
  598. MS_EXCEPTION_IF_NULL(kernel_info);
  599. auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  600. if (addr == nullptr) {
  601. MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
  602. << "] workspace addr is not exist";
  603. }
  604. return addr;
  605. }
// set infer shapes and types of anf node
// Rebuilds the node's abstract from parallel 'types'/'shapes' lists: a single
// entry becomes an AbstractTensor, several become an AbstractTuple.
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  }
  if (shapes.empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  } else if (shapes.size() == 1) {
    // single output handle
    std::vector<int> shape_int;
    std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
    node->set_abstract(abstract);
  } else {
    // multiple output handle
    std::vector<AbstractBasePtr> abstract_list;
    for (size_t i = 0; i < types.size(); ++i) {
      std::vector<int> shape_int;
      std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
      abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
    }
    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
    node->set_abstract(abstract_tuple);
  }
}
  633. // copy an abstract of a node to another node
  634. void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  635. to_node->set_abstract(from_node->abstract());
  636. }
  637. kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  638. MS_EXCEPTION_IF_NULL(node);
  639. auto kernel_info = node->kernel_info();
  640. MS_EXCEPTION_IF_NULL(kernel_info);
  641. // select_kernel_build_info() has checked whether return pointer is null
  642. auto build_info = kernel_info->select_kernel_build_info();
  643. MS_EXCEPTION_IF_NULL(build_info);
  644. return build_info->op_pattern();
  645. }
  646. // get KernelBuildType of node, such as ATT,RT,FWK and so on
  647. KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  648. MS_EXCEPTION_IF_NULL(node);
  649. auto kernel_info = node->kernel_info();
  650. MS_EXCEPTION_IF_NULL(kernel_info);
  651. // select_kernel_build_info() has checked whether return pointer is null
  652. auto build_info = kernel_info->select_kernel_build_info();
  653. MS_EXCEPTION_IF_NULL(build_info);
  654. return build_info->kernel_type();
  655. }
  656. kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  657. MS_EXCEPTION_IF_NULL(node);
  658. auto kernel_info = node->kernel_info();
  659. MS_EXCEPTION_IF_NULL(kernel_info);
  660. auto build_info = kernel_info->select_kernel_build_info();
  661. MS_EXCEPTION_IF_NULL(build_info);
  662. return build_info->processor();
  663. }
  664. kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  665. MS_EXCEPTION_IF_NULL(node);
  666. auto kernel_info = node->kernel_info();
  667. MS_EXCEPTION_IF_NULL(kernel_info);
  668. auto build_info = kernel_info->select_kernel_build_info();
  669. MS_EXCEPTION_IF_NULL(build_info);
  670. return build_info->fusion_type();
  671. }
  672. // set select kernel_build_info
  673. void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  674. MS_EXCEPTION_IF_NULL(node);
  675. auto kernel_info = node->kernel_info();
  676. MS_EXCEPTION_IF_NULL(kernel_info);
  677. return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
  678. }
  679. // get select kernel_build_info
  680. KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  681. MS_EXCEPTION_IF_NULL(node);
  682. auto kernel_info = node->kernel_info();
  683. MS_EXCEPTION_IF_NULL(kernel_info);
  684. return kernel_info->GetMutableSelectKernelBuildInfo();
  685. }
  686. // get kernelMode
  687. KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  688. MS_EXCEPTION_IF_NULL(node);
  689. auto kernel_info = node->kernel_info();
  690. MS_EXCEPTION_IF_NULL(kernel_info);
  691. return kernel_info->MutableKernelMod();
  692. }
  693. // set kernel mod
  694. void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  695. MS_EXCEPTION_IF_NULL(node);
  696. auto kernel_info = node->kernel_info();
  697. MS_EXCEPTION_IF_NULL(kernel_info);
  698. kernel_info->set_kernel_mod(kernel_mod);
  699. }
  700. bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  701. MS_EXCEPTION_IF_NULL(node);
  702. // parameter and value node is not a real kernel too
  703. if (!node->isa<CNode>()) {
  704. return true;
  705. }
  706. auto cnode = node->cast<CNodePtr>();
  707. MS_EXCEPTION_IF_NULL(cnode);
  708. if (cnode->inputs().empty()) {
  709. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  710. }
  711. auto input = cnode->inputs()[0];
  712. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  713. IsPrimitive(input, prim::kPrimTensorSummary) ||
  714. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  715. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  716. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  717. IsPrimitive(input, prim::kPrimReturn);
  718. return !is_virtual_node;
  719. }
  720. bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  721. MS_EXCEPTION_IF_NULL(node);
  722. // parameter and value node is not a real cnode kernel
  723. if (!node->isa<CNode>()) {
  724. return false;
  725. }
  726. // return considered as a real node
  727. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  728. return true;
  729. }
  730. return IsRealKernel(node);
  731. }
  732. bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  733. MS_EXCEPTION_IF_NULL(node);
  734. return node->has_default();
  735. }
  736. void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  737. MS_EXCEPTION_IF_NULL(node);
  738. auto kernel_info = node->kernel_info();
  739. MS_EXCEPTION_IF_NULL(kernel_info);
  740. kernel_info->set_stream_id(stream_id);
  741. }
  742. uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  743. MS_EXCEPTION_IF_NULL(node);
  744. auto kernel_info = node->kernel_info();
  745. MS_EXCEPTION_IF_NULL(kernel_info);
  746. return kernel_info->stream_id();
  747. }
  748. void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  749. MS_EXCEPTION_IF_NULL(node);
  750. auto kernel_info = node->kernel_info();
  751. MS_EXCEPTION_IF_NULL(kernel_info);
  752. kernel_info->set_stream_distinction_label(stream_label);
  753. }
  754. uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  755. MS_EXCEPTION_IF_NULL(node);
  756. auto kernel_info = node->kernel_info();
  757. MS_EXCEPTION_IF_NULL(kernel_info);
  758. return kernel_info->stream_distinction_label();
  759. }
  760. void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  761. MS_EXCEPTION_IF_NULL(node);
  762. auto kernel_info = node->kernel_info();
  763. MS_EXCEPTION_IF_NULL(kernel_info);
  764. kernel_info->set_graph_id(graph_id);
  765. }
  766. uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  767. MS_EXCEPTION_IF_NULL(node);
  768. auto kernel_info = node->kernel_info();
  769. MS_EXCEPTION_IF_NULL(kernel_info);
  770. return kernel_info->graph_id();
  771. }
  772. bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  773. MS_EXCEPTION_IF_NULL(anf);
  774. TypePtr type = anf->Type();
  775. MS_EXCEPTION_IF_NULL(type);
  776. return type->isa<Tuple>();
  777. }
  778. AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  779. MS_EXCEPTION_IF_NULL(node);
  780. auto get_input_index = index + 1;
  781. if (index + 1 > node->inputs().size()) {
  782. MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
  783. << node->inputs().size();
  784. }
  785. // input 0 is primitive node
  786. return node->input(get_input_index);
  787. }
  788. bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  789. MS_EXCEPTION_IF_NULL(node);
  790. if (node->isa<ValueNode>()) {
  791. return false;
  792. }
  793. auto kernel_info = node->kernel_info();
  794. MS_EXCEPTION_IF_NULL(kernel_info);
  795. return kernel_info->is_feature_map();
  796. }
  797. bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  798. if (!node->isa<CNode>()) {
  799. MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  800. }
  801. auto cnode = node->cast<CNodePtr>();
  802. MS_EXCEPTION_IF_NULL(cnode);
  803. auto input_node = cnode->input(input_index + 1);
  804. return IsFeatureMapOutput(input_node);
  805. }
  806. size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  807. MS_EXCEPTION_IF_NULL(anf_node);
  808. static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
  809. {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
  810. {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
  811. {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  812. {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
  813. {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
  814. {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  815. {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  816. {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  817. {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  818. {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  819. {prim::kPrimApplyCenteredRMSProp->name(),
  820. {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
  821. size_t ret = cur_index;
  822. auto node_name = AnfAlgo::GetCNodeName(anf_node);
  823. if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
  824. auto find = spec_node_list.find(node_name);
  825. if (find != spec_node_list.end()) {
  826. ret = find->second[cur_index];
  827. MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
  828. }
  829. }
  830. return ret;
  831. }
  832. void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  833. MS_EXCEPTION_IF_NULL(node);
  834. MS_EXCEPTION_IF_NULL(input_node);
  835. node->set_input(index + 1, input_node);
  836. }
  837. bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  838. MS_EXCEPTION_IF_NULL(node);
  839. if (!node->isa<CNode>()) {
  840. return false;
  841. }
  842. auto kernel_name = AnfAlgo::GetCNodeName(node);
  843. if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
  844. kernel_name == kReduceScatterOpName) {
  845. return true;
  846. }
  847. return false;
  848. }
  849. bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  850. auto kernel_name = AnfAlgo::GetCNodeName(node);
  851. return kernel_name == kGetNextOpName;
  852. }
  853. FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  854. MS_EXCEPTION_IF_NULL(node);
  855. auto value_node = node->cast<ValueNodePtr>();
  856. if (value_node == nullptr) {
  857. return nullptr;
  858. }
  859. auto value = value_node->value();
  860. if (value == nullptr) {
  861. return nullptr;
  862. }
  863. auto func_graph = value->cast<FuncGraphPtr>();
  864. return func_graph;
  865. }
  866. std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
  867. if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
  868. MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
  869. }
  870. MS_EXCEPTION_IF_NULL(call_node);
  871. auto input1 = call_node->input(1);
  872. MS_EXCEPTION_IF_NULL(input1);
  873. if (input1->isa<ValueNode>()) {
  874. auto value_node = input1->cast<ValueNodePtr>();
  875. MS_EXCEPTION_IF_NULL(value_node);
  876. auto kernel_graph = value_node->value();
  877. MS_EXCEPTION_IF_NULL(kernel_graph);
  878. return {kernel_graph->cast<KernelGraphPtr>()};
  879. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  880. auto switch_node = input1->cast<CNodePtr>();
  881. MS_EXCEPTION_IF_NULL(switch_node);
  882. auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
  883. auto partial = switch_node->input(input_index);
  884. MS_EXCEPTION_IF_NULL(partial);
  885. auto partial_cnode = partial->cast<CNodePtr>();
  886. MS_EXCEPTION_IF_NULL(partial_cnode);
  887. auto graph_node = partial_cnode->input(1);
  888. MS_EXCEPTION_IF_NULL(graph_node);
  889. auto graph_value_node = graph_node->cast<ValueNodePtr>();
  890. MS_EXCEPTION_IF_NULL(graph_value_node);
  891. auto graph_value = graph_value_node->value();
  892. MS_EXCEPTION_IF_NULL(graph_value);
  893. auto child_graph = graph_value->cast<KernelGraphPtr>();
  894. return child_graph;
  895. };
  896. return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
  897. }
  898. return {};
  899. }
  900. bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  901. MS_EXCEPTION_IF_NULL(call_node);
  902. if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
  903. MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
  904. }
  905. auto input1 = call_node->input(1);
  906. if (input1->isa<ValueNode>()) {
  907. return false;
  908. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  909. return true;
  910. }
  911. MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
  912. }
  913. bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
  914. auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  915. if (shape.empty()) {
  916. return true;
  917. }
  918. return shape.size() == kShape1dDims && shape[0] == 1;
  919. }
  920. bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
  921. auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  922. if (shape.empty()) {
  923. return true;
  924. }
  925. return shape.size() == kShape1dDims && shape[0] == 1;
  926. }
  927. } // namespace session
  928. } // namespace mindspore