You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

anf_runtime_algorithm.cc 40 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/anf_runtime_algorithm.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "operator/ops.h"
  24. #include "utils/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "device/device_address.h"
  27. #include "pre_activate/common/helper.h"
  28. #include "kernel/kernel.h"
  29. #include "kernel/kernel_build_info.h"
  30. #include "common/utils.h"
  31. #include "common/trans.h"
  32. namespace mindspore {
  33. namespace session {
  34. using abstract::AbstractTensor;
  35. using abstract::AbstractTuple;
  36. using device::KernelInfo;
  37. using device::ascend::AscendDeviceAddress;
  38. using kernel::KernelBuildInfoPtr;
  39. using kernel::KernelMod;
  40. using kernel::KernelModPtr;
  41. namespace {
  42. std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  43. MS_EXCEPTION_IF_NULL(shape);
  44. std::vector<size_t> shape_size_t;
  45. std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  46. return shape_size_t;
  47. }
  48. } // namespace
// Resolve (anf_node, index) to the node that actually produces the value,
// looking through wrapper primitives recursively:
//  - MakeTuple: descend into tuple element `index`
//  - TupleGetItem: descend into the real input using the stored item index
//  - Depend / ControlDepend: descend into the real (data) input
// ValueNode and Parameter terminate the search with output index 0; any other
// cnode is treated as a real kernel and returned with `index` unchanged.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // inputs()[0] is the primitive, so tuple element i lives at input(i + 1).
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Second input of tuple_get_item is a constant holding the output index.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      // A real kernel: this node itself produces output `index`.
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Like VisitKernel, but (a) stops early and returns `anf_node` itself when it
// matches one of the primitives in `return_types`, and (b) when
// `visit_nop_node` is true additionally looks through single-input nop nodes.
// Note: unlike VisitKernel, MakeTuple is NOT looked through here unless the
// caller adds it to `return_types` (see GetAllOutput).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Early-out: caller asked for nodes of these primitive types verbatim.
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // Second input of tuple_get_item is a constant holding the output index.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A nop node forwards its single real input (inputs() == {prim, input}).
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Collect all real output nodes reachable from `node`, flattening nested
// MakeTuple structures recursively. Each returned AnfNodePtr is the `.first`
// of a VisitKernelWithReturnType resolution (output indices are discarded).
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // Ask the visitor to stop at MakeTuple so we can expand its elements here.
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    // inputs()[0] is the MakeTuple primitive; elements start at index 1.
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  ret.push_back(item_with_index.first);
  return ret;
}
  147. AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  148. MS_EXCEPTION_IF_NULL(node);
  149. return node->input(kAnfPrimitiveIndex);
  150. }
  151. PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. auto cnode = node->cast<CNodePtr>();
  154. MS_EXCEPTION_IF_NULL(cnode);
  155. auto attr_input = GetCNodePrimitiveNode(cnode);
  156. MS_EXCEPTION_IF_NULL(attr_input);
  157. auto value_node = attr_input->cast<ValueNodePtr>();
  158. MS_EXCEPTION_IF_NULL(value_node);
  159. auto value = value_node->value();
  160. MS_EXCEPTION_IF_NULL(value);
  161. auto primitive = value->cast<PrimitivePtr>();
  162. return primitive;
  163. }
  164. bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. if (!node->isa<CNode>()) {
  167. return false;
  168. }
  169. auto cnode = node->cast<CNodePtr>();
  170. MS_EXCEPTION_IF_NULL(cnode);
  171. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  172. }
  173. std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  174. MS_EXCEPTION_IF_NULL(node);
  175. if (node->isa<CNode>()) {
  176. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  177. MS_EXCEPTION_IF_NULL(primitive);
  178. return primitive->name();
  179. }
  180. MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
  181. }
  182. std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  183. MS_EXCEPTION_IF_NULL(node);
  184. return node->DebugString();
  185. }
// Set attribute `key` = `value` on the primitive of cnode `node`.
// Throws when `node` is not a cnode (only cnodes carry a primitive).
void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->set_attr(key, value);
}
// Copy attribute `key` from cnode `from` to cnode `to`, keeping the key name.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
// Copy attribute `old_key` of cnode `from` onto cnode `to` under `new_key`.
// Throws when either node is not a cnode.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
                                       const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  // NOTE(review): presumably GetAttr returns null when `old_key` is absent and
  // set_attr then stores that null — confirm callers only pass existing keys.
  to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
}
  212. void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  213. MS_EXCEPTION_IF_NULL(from);
  214. MS_EXCEPTION_IF_NULL(to);
  215. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  216. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
  217. << from->DebugString();
  218. }
  219. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  220. MS_EXCEPTION_IF_NULL(from_primitive);
  221. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  222. MS_EXCEPTION_IF_NULL(to_primitive);
  223. (void)to_primitive->SetAttrs(from_primitive->attrs());
  224. }
  225. void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  226. MS_EXCEPTION_IF_NULL(node);
  227. if (!node->isa<CNode>()) {
  228. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  229. }
  230. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  231. MS_EXCEPTION_IF_NULL(primitive);
  232. primitive->EraseAttr(key);
  233. }
  234. bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
  235. MS_EXCEPTION_IF_NULL(node);
  236. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  237. MS_EXCEPTION_IF_NULL(primitive);
  238. return primitive->HasAttr(key);
  239. }
  240. size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  241. MS_EXCEPTION_IF_NULL(node);
  242. if (!node->isa<CNode>()) {
  243. MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  244. }
  245. auto cnode = node->cast<CNodePtr>();
  246. MS_EXCEPTION_IF_NULL(cnode);
  247. size_t input_num = cnode->inputs().size();
  248. if (input_num == 0) {
  249. MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  250. }
  251. // exclude intputs[0],which is value_node storing attr,inputs left are real input
  252. return input_num - 1;
  253. }
  254. size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  255. MS_EXCEPTION_IF_NULL(node);
  256. TypePtr type = node->Type();
  257. if (type == nullptr) {
  258. return 0;
  259. }
  260. if (type->isa<Tuple>()) {
  261. auto tuple_type = type->cast<TuplePtr>();
  262. MS_EXCEPTION_IF_NULL(tuple_type);
  263. return tuple_type->size();
  264. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  265. return 1;
  266. } else if (type->isa<TypeNone>()) {
  267. return 0;
  268. } else {
  269. return 1;
  270. }
  271. }
// Device data format (e.g. default / 5D formats) of output `output_idx`.
// Non-real kernels delegate to the node that actually produces the value.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): bound uses '>' not '>=', so output_idx == GetOutputTensorNum
  // slips through — confirm whether that off-by-one is intentional (the same
  // pattern appears in the sibling accessors below).
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  if (!AnfAlgo::IsRealKernel(node)) {
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOutputFormat(output_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format";
  }
  return format;
}
  293. std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  294. MS_EXCEPTION_IF_NULL(node);
  295. if (input_idx > GetInputTensorNum(node)) {
  296. MS_LOG(EXCEPTION) << "Input index :" << input_idx
  297. << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
  298. << node->DebugString() << "]";
  299. }
  300. if (!IsRealKernel(node)) {
  301. GetPrevNodeOutputFormat(node, input_idx);
  302. }
  303. auto kernel_info = node->kernel_info();
  304. MS_EXCEPTION_IF_NULL(kernel_info);
  305. auto build_info = kernel_info->select_kernel_build_info();
  306. MS_EXCEPTION_IF_NULL(build_info);
  307. auto format = build_info->GetInputFormat(input_idx);
  308. if (format == kernel::KernelBuildInfo::kInvalidFormat) {
  309. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  310. << " has a invalid input format";
  311. }
  312. return format;
  313. }
// Resolve input `input_idx` of cnode `anf_node` to the (producer, output
// index) pair via VisitKernel. Throws when `anf_node` is not a cnode or the
// index is out of range.
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 skips the primitive stored at inputs()[0].
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  }
  auto node = cnode->input(input_idx + 1);
  MS_EXCEPTION_IF_NULL(node);
  return VisitKernel(node, 0);
}
  328. std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  329. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  330. return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  331. }
  332. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  333. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  334. return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
  335. }
// Inferred (host-side) shape of output `output_idx` of `node`, as size_t.
// Single-output nodes expose abstract::Shape directly; multi-output nodes use
// a TupleShape indexed by `output_idx`. NoShape (scalars) yields {}.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar element: empty shape.
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}
  364. std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  365. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  366. return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
  367. }
  368. std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  369. auto format = GetOutputFormat(node, output_idx);
  370. auto infer_shape = GetOutputInferShape(node, output_idx);
  371. if (infer_shape.empty()) {
  372. return infer_shape;
  373. }
  374. // if format is default_format or NC1KHKWHWC0,device shape = original shape
  375. if (trans::IsNeedPadding(format, infer_shape.size())) {
  376. infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  377. }
  378. return trans::TransShapeToDevice(infer_shape, format);
  379. }
  380. std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  381. auto format = GetInputFormat(node, input_idx);
  382. auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  383. if (infer_shape.empty()) {
  384. return infer_shape;
  385. }
  386. // if format is default_format or NC1KHKWHWC0,device shape = original shape
  387. if (trans::IsNeedPadding(format, infer_shape.size())) {
  388. infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  389. }
  390. return trans::TransShapeToDevice(infer_shape, format);
  391. }
// Reshape type (axis list) of input `input_idx`; {} when the selected kernel
// uses default padding. Non-real kernels delegate to the producer's output.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsInputDefaultPadding()) {
    return {};
  }
  return build_info->GetInputReshapeType(input_idx);
}
// Reshape type (axis list) of output `output_idx`; {} when the selected kernel
// uses default padding. Non-real kernels delegate to the producer's output.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsOutputDefaultPadding()) {
    return {};
  }
  return build_info->GetOutputReshapeType(output_idx);
}
// Inferred element data type of output `output_idx` of `node`.
// Tensor types report their element type; tuple types are indexed by
// `output_idx`; Number and other plain types report their own type id.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      // Unexpected element type: warn but still return its raw type id.
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  // Fallback: report the node type's own id.
  return type_ptr->type_id();
}
  464. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  465. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  466. return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  467. }
// Device data type selected for output `output_idx` of `node`; non-real
// kernels delegate to the producing node. Throws on an unset (kNumberTypeEnd)
// device type.
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputDeviceDataType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetOutputDeviceType(output_idx);
  // kNumberTypeEnd marks "no device type selected".
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype";
  }
  return dtype;
}
// Device data type selected for input `input_idx` of `node`. Throws on an
// unset (kNumberTypeEnd) device type.
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                      << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // NOTE(review): delegates with index 0 rather than input_idx, unlike the
    // sibling accessors (GetInputFormat, GetOutputDeviceDataType) — confirm
    // this is intentional; it only differs for multi-input non-real kernels.
    return GetPrevNodeOutputDeviceDataType(node, 0);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetInputDeviceType(input_idx);
  // kNumberTypeEnd marks "no device type selected".
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype";
  }
  return dtype;
}
  508. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  509. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  510. return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  511. }
// Get the (const) device address of output `output_idx` of `node`.
// Nop nodes forward the address of their single real input's producer.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // inputs() == {prim, real_input} for a valid nop node.
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Mutable (shared-pointer) variant of GetOutputAddr; nop nodes forward the
// address of their single real input's producer.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // inputs() == {prim, real_input} for a valid nop node.
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
  553. // get output device addr of anf_node
  554. bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  555. MS_EXCEPTION_IF_NULL(node);
  556. if (output_idx > GetOutputTensorNum(node)) {
  557. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  558. << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  559. }
  560. auto kernel_info = node->kernel_info();
  561. MS_EXCEPTION_IF_NULL(kernel_info);
  562. return kernel_info->OutputAddrExist(output_idx);
  563. }
  564. const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  565. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  566. return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
  567. }
  568. DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  569. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  570. return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
  571. }
  572. // set output device addr of anf_node
  573. void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  574. MS_EXCEPTION_IF_NULL(node);
  575. auto kernel_info = node->kernel_info();
  576. MS_EXCEPTION_IF_NULL(kernel_info);
  577. if (!kernel_info->SetOutputAddr(addr, output_idx)) {
  578. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  579. }
  580. }
  581. // set workspace device addr of anf_node
  582. void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  583. MS_EXCEPTION_IF_NULL(node);
  584. auto kernel_info = node->kernel_info();
  585. MS_EXCEPTION_IF_NULL(kernel_info);
  586. if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
  587. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  588. }
  589. }
// Get the device address of workspace slot `output_idx` of `node`; throws
// when no workspace address has been assigned for that slot.
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << "] workspace addr is not exist";
  }
  return addr;
}
// Rebuild `node`'s abstract from parallel `types`/`shapes` vectors: a single
// entry yields an AbstractTensor, several yield an AbstractTuple of tensors.
// Throws when the vectors differ in length or are empty.
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  }
  if (shapes.empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  } else if (shapes.size() == 1) {
    // single output handle
    // AbstractTensor expects int dims, so convert from size_t.
    std::vector<int> shape_int;
    std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
    node->set_abstract(abstract);
  } else {
    // multiple output handle
    std::vector<AbstractBasePtr> abstract_list;
    for (size_t i = 0; i < types.size(); ++i) {
      std::vector<int> shape_int;
      std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
      abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
    }
    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
    node->set_abstract(abstract_tuple);
  }
}
  629. // copy an abstract of a node to another node
  630. void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  631. to_node->set_abstract(from_node->abstract());
  632. }
  633. kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  634. MS_EXCEPTION_IF_NULL(node);
  635. auto kernel_info = node->kernel_info();
  636. MS_EXCEPTION_IF_NULL(kernel_info);
  637. // select_kernel_build_info() has checked whether return pointer is null
  638. auto build_info = kernel_info->select_kernel_build_info();
  639. MS_EXCEPTION_IF_NULL(build_info);
  640. return build_info->op_pattern();
  641. }
  642. // get KernelBuildType of node, such as ATT,RT,FWK and so on
  643. KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  644. MS_EXCEPTION_IF_NULL(node);
  645. auto kernel_info = node->kernel_info();
  646. MS_EXCEPTION_IF_NULL(kernel_info);
  647. // select_kernel_build_info() has checked whether return pointer is null
  648. auto build_info = kernel_info->select_kernel_build_info();
  649. MS_EXCEPTION_IF_NULL(build_info);
  650. return build_info->kernel_type();
  651. }
  652. kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  653. MS_EXCEPTION_IF_NULL(node);
  654. auto kernel_info = node->kernel_info();
  655. MS_EXCEPTION_IF_NULL(kernel_info);
  656. auto build_info = kernel_info->select_kernel_build_info();
  657. MS_EXCEPTION_IF_NULL(build_info);
  658. return build_info->processor();
  659. }
  660. kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  661. MS_EXCEPTION_IF_NULL(node);
  662. auto kernel_info = node->kernel_info();
  663. MS_EXCEPTION_IF_NULL(kernel_info);
  664. auto build_info = kernel_info->select_kernel_build_info();
  665. MS_EXCEPTION_IF_NULL(build_info);
  666. return build_info->fusion_type();
  667. }
  668. // set select kernel_build_info
  669. void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  670. MS_EXCEPTION_IF_NULL(node);
  671. auto kernel_info = node->kernel_info();
  672. MS_EXCEPTION_IF_NULL(kernel_info);
  673. return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
  674. }
  675. // get select kernel_build_info
  676. KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  677. MS_EXCEPTION_IF_NULL(node);
  678. auto kernel_info = node->kernel_info();
  679. MS_EXCEPTION_IF_NULL(kernel_info);
  680. return kernel_info->GetMutableSelectKernelBuildInfo();
  681. }
  682. // get kernelMode
  683. KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  684. MS_EXCEPTION_IF_NULL(node);
  685. auto kernel_info = node->kernel_info();
  686. MS_EXCEPTION_IF_NULL(kernel_info);
  687. return kernel_info->MutableKernelMod();
  688. }
  689. // set kernel mod
  690. void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  691. MS_EXCEPTION_IF_NULL(node);
  692. auto kernel_info = node->kernel_info();
  693. MS_EXCEPTION_IF_NULL(kernel_info);
  694. kernel_info->set_kernel_mod(kernel_mod);
  695. }
  696. bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  697. MS_EXCEPTION_IF_NULL(node);
  698. // parameter and value node is not a real kernel too
  699. if (!node->isa<CNode>()) {
  700. return true;
  701. }
  702. auto cnode = node->cast<CNodePtr>();
  703. MS_EXCEPTION_IF_NULL(cnode);
  704. if (cnode->inputs().empty()) {
  705. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  706. }
  707. auto input = cnode->inputs()[0];
  708. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  709. IsPrimitive(input, prim::kPrimTensorSummary) ||
  710. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  711. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  712. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  713. IsPrimitive(input, prim::kPrimReturn);
  714. return !is_virtual_node;
  715. }
  716. bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  717. MS_EXCEPTION_IF_NULL(node);
  718. // parameter and value node is not a real cnode kernel
  719. if (!node->isa<CNode>()) {
  720. return false;
  721. }
  722. // return considered as a real node
  723. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  724. return true;
  725. }
  726. return IsRealKernel(node);
  727. }
  728. bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  729. MS_EXCEPTION_IF_NULL(node);
  730. return node->has_default();
  731. }
  732. void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  733. MS_EXCEPTION_IF_NULL(node);
  734. auto kernel_info = node->kernel_info();
  735. MS_EXCEPTION_IF_NULL(kernel_info);
  736. kernel_info->set_stream_id(stream_id);
  737. }
  738. uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  739. MS_EXCEPTION_IF_NULL(node);
  740. auto kernel_info = node->kernel_info();
  741. MS_EXCEPTION_IF_NULL(kernel_info);
  742. return kernel_info->stream_id();
  743. }
  744. void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  745. MS_EXCEPTION_IF_NULL(node);
  746. auto kernel_info = node->kernel_info();
  747. MS_EXCEPTION_IF_NULL(kernel_info);
  748. kernel_info->set_stream_distinction_label(stream_label);
  749. }
  750. uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  751. MS_EXCEPTION_IF_NULL(node);
  752. auto kernel_info = node->kernel_info();
  753. MS_EXCEPTION_IF_NULL(kernel_info);
  754. return kernel_info->stream_distinction_label();
  755. }
  756. void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  757. MS_EXCEPTION_IF_NULL(node);
  758. auto kernel_info = node->kernel_info();
  759. MS_EXCEPTION_IF_NULL(kernel_info);
  760. kernel_info->set_graph_id(graph_id);
  761. }
  762. uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  763. MS_EXCEPTION_IF_NULL(node);
  764. auto kernel_info = node->kernel_info();
  765. MS_EXCEPTION_IF_NULL(kernel_info);
  766. return kernel_info->graph_id();
  767. }
  768. bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  769. MS_EXCEPTION_IF_NULL(anf);
  770. TypePtr type = anf->Type();
  771. MS_EXCEPTION_IF_NULL(type);
  772. return type->isa<Tuple>();
  773. }
  774. AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  775. MS_EXCEPTION_IF_NULL(node);
  776. auto get_input_index = index + 1;
  777. if (index + 1 > node->inputs().size()) {
  778. MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
  779. << node->inputs().size();
  780. }
  781. // input 0 is primitive node
  782. return node->input(get_input_index);
  783. }
  784. bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  785. MS_EXCEPTION_IF_NULL(node);
  786. if (node->isa<ValueNode>()) {
  787. return false;
  788. }
  789. auto kernel_info = node->kernel_info();
  790. MS_EXCEPTION_IF_NULL(kernel_info);
  791. return kernel_info->is_feature_map();
  792. }
  793. bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  794. if (!node->isa<CNode>()) {
  795. MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  796. }
  797. auto cnode = node->cast<CNodePtr>();
  798. MS_EXCEPTION_IF_NULL(cnode);
  799. auto input_node = cnode->input(input_index + 1);
  800. return IsFeatureMapOutput(input_node);
  801. }
  802. size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  803. MS_EXCEPTION_IF_NULL(anf_node);
  804. static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
  805. {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
  806. {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
  807. {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  808. {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
  809. {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
  810. {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  811. {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  812. {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  813. {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  814. {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}};
  815. size_t ret = cur_index;
  816. auto node_name = AnfAlgo::GetCNodeName(anf_node);
  817. if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
  818. auto find = spec_node_list.find(node_name);
  819. if (find != spec_node_list.end()) {
  820. ret = find->second[cur_index];
  821. MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
  822. }
  823. }
  824. return ret;
  825. }
  826. void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  827. MS_EXCEPTION_IF_NULL(node);
  828. MS_EXCEPTION_IF_NULL(input_node);
  829. node->set_input(index + 1, input_node);
  830. }
  831. bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  832. MS_EXCEPTION_IF_NULL(node);
  833. if (!node->isa<CNode>()) {
  834. return false;
  835. }
  836. auto kernel_name = AnfAlgo::GetCNodeName(node);
  837. if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
  838. kernel_name == kReduceScatterOpName) {
  839. return true;
  840. }
  841. return false;
  842. }
  843. bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  844. auto kernel_name = AnfAlgo::GetCNodeName(node);
  845. return kernel_name == kGetNextOpName;
  846. }
  847. FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  848. MS_EXCEPTION_IF_NULL(node);
  849. auto value_node = node->cast<ValueNodePtr>();
  850. if (value_node == nullptr) {
  851. return nullptr;
  852. }
  853. auto value = value_node->value();
  854. if (value == nullptr) {
  855. return nullptr;
  856. }
  857. auto func_graph = value->cast<FuncGraphPtr>();
  858. return func_graph;
  859. }
  860. std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
  861. if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
  862. MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
  863. }
  864. MS_EXCEPTION_IF_NULL(call_node);
  865. auto input1 = call_node->input(1);
  866. MS_EXCEPTION_IF_NULL(input1);
  867. if (input1->isa<ValueNode>()) {
  868. auto value_node = input1->cast<ValueNodePtr>();
  869. MS_EXCEPTION_IF_NULL(value_node);
  870. auto kernel_graph = value_node->value();
  871. MS_EXCEPTION_IF_NULL(kernel_graph);
  872. return {kernel_graph->cast<KernelGraphPtr>()};
  873. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  874. auto switch_node = input1->cast<CNodePtr>();
  875. MS_EXCEPTION_IF_NULL(switch_node);
  876. MS_LOG(INFO) << "switch : " << switch_node->DebugString();
  877. auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
  878. auto partial = switch_node->input(input_index);
  879. MS_EXCEPTION_IF_NULL(partial);
  880. auto partial_cnode = partial->cast<CNodePtr>();
  881. MS_EXCEPTION_IF_NULL(partial_cnode);
  882. auto graph_node = partial_cnode->input(1);
  883. MS_EXCEPTION_IF_NULL(graph_node);
  884. MS_LOG(INFO) << graph_node->DebugString();
  885. auto graph_value_node = graph_node->cast<ValueNodePtr>();
  886. MS_EXCEPTION_IF_NULL(graph_value_node);
  887. auto graph_value = graph_value_node->value();
  888. MS_EXCEPTION_IF_NULL(graph_value);
  889. auto child_graph = graph_value->cast<KernelGraphPtr>();
  890. return child_graph;
  891. };
  892. return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
  893. }
  894. return {};
  895. }
  896. bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  897. MS_EXCEPTION_IF_NULL(call_node);
  898. if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
  899. MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
  900. }
  901. auto input1 = call_node->input(1);
  902. if (input1->isa<ValueNode>()) {
  903. return false;
  904. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  905. return true;
  906. }
  907. MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
  908. }
  909. } // namespace session
  910. } // namespace mindspore