You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

anf_runtime_algorithm.cc 44 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/anf_runtime_algorithm.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <map>
  20. #include <set>
  21. #include "ir/anf.h"
  22. #include "ir/func_graph.h"
  23. #include "operator/ops.h"
  24. #include "utils/utils.h"
  25. #include "device/kernel_info.h"
  26. #include "device/device_address.h"
  27. #include "pre_activate/common/helper.h"
  28. #include "kernel/kernel.h"
  29. #include "kernel/kernel_build_info.h"
  30. #include "common/utils.h"
  31. #include "common/trans.h"
  32. namespace mindspore {
  33. namespace session {
  34. using abstract::AbstractTensor;
  35. using abstract::AbstractTuple;
  36. using device::KernelInfo;
  37. using device::ascend::AscendDeviceAddress;
  38. using kernel::KernelBuildInfoPtr;
  39. using kernel::KernelMod;
  40. using kernel::KernelModPtr;
  41. namespace {
  42. std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  43. MS_EXCEPTION_IF_NULL(shape);
  44. std::vector<size_t> shape_size_t;
  45. std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  46. return shape_size_t;
  47. }
  48. } // namespace
// Resolve (anf_node, index) to the real kernel producing that value, seeing
// through virtual wrapper nodes:
//  - ValueNode / Parameter: returned as-is with output index 0.
//  - MakeTuple: recurse into the tuple element selected by `index`.
//  - TupleGetItem: recurse into the real input with the stored item index.
//  - Depend / ControlDepend: recurse into the real (first) input.
//  - any other CNode: treated as the producing kernel; returned with `index`.
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // inputs[0] is the primitive, so tuple element i lives at input(i + 1).
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // The second real input holds the constant output index being selected.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      // Depend/ControlDepend forward the value of their real input.
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
// Like VisitKernel, but (a) stops early and returns the node unresolved when
// it matches one of the primitives in `return_types`, and (b) when
// `visit_nop_node` is true, sees through nop nodes (a nop CNode forwards its
// single real input).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Caller requested these node kinds be returned without resolution.
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      // The second real input is the constant element index being selected.
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A valid nop node has exactly {primitive, real input}.
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}
  127. std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
  128. const std::vector<PrimitivePtr> &return_types) {
  129. std::vector<AnfNodePtr> ret;
  130. auto return_prim_type = return_types;
  131. // if visited make_tuple should return back
  132. return_prim_type.push_back(prim::kPrimMakeTuple);
  133. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  134. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  135. MS_EXCEPTION_IF_NULL(item_with_index.first);
  136. auto make_tuple = item_with_index.first->cast<CNodePtr>();
  137. MS_EXCEPTION_IF_NULL(make_tuple);
  138. for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
  139. auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
  140. (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
  141. }
  142. return ret;
  143. }
  144. ret.push_back(item_with_index.first);
  145. return ret;
  146. }
  147. AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  148. MS_EXCEPTION_IF_NULL(node);
  149. return node->input(kAnfPrimitiveIndex);
  150. }
  151. PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  152. MS_EXCEPTION_IF_NULL(node);
  153. auto cnode = node->cast<CNodePtr>();
  154. MS_EXCEPTION_IF_NULL(cnode);
  155. auto attr_input = GetCNodePrimitiveNode(cnode);
  156. MS_EXCEPTION_IF_NULL(attr_input);
  157. auto value_node = attr_input->cast<ValueNodePtr>();
  158. MS_EXCEPTION_IF_NULL(value_node);
  159. auto value = value_node->value();
  160. MS_EXCEPTION_IF_NULL(value);
  161. auto primitive = value->cast<PrimitivePtr>();
  162. return primitive;
  163. }
  164. bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  165. MS_EXCEPTION_IF_NULL(node);
  166. if (!node->isa<CNode>()) {
  167. return false;
  168. }
  169. auto cnode = node->cast<CNodePtr>();
  170. MS_EXCEPTION_IF_NULL(cnode);
  171. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  172. }
  173. FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
  174. MS_EXCEPTION_IF_NULL(node);
  175. auto cnode = node->cast<CNodePtr>();
  176. MS_EXCEPTION_IF_NULL(cnode);
  177. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  178. MS_EXCEPTION_IF_NULL(attr_input);
  179. auto value_node = attr_input->cast<ValueNodePtr>();
  180. MS_EXCEPTION_IF_NULL(value_node);
  181. auto value = value_node->value();
  182. MS_EXCEPTION_IF_NULL(value);
  183. return value->cast<FuncGraphPtr>();
  184. }
  185. std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  186. MS_EXCEPTION_IF_NULL(node);
  187. if (node->isa<CNode>()) {
  188. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  189. if (primitive != nullptr) {
  190. return primitive->name();
  191. }
  192. auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  193. MS_EXCEPTION_IF_NULL(func_graph);
  194. return func_graph->ToString();
  195. }
  196. MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
  197. }
  198. std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  199. MS_EXCEPTION_IF_NULL(node);
  200. return node->DebugString();
  201. }
  202. void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  203. MS_EXCEPTION_IF_NULL(node);
  204. if (!node->isa<CNode>()) {
  205. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  206. }
  207. // single op cnode.
  208. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  209. if (primitive != nullptr) {
  210. primitive->set_attr(key, value);
  211. return;
  212. }
  213. // graph kernel cnode.
  214. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
  215. MS_EXCEPTION_IF_NULL(fg);
  216. fg->set_attr(key, value);
  217. }
// Copy attribute `key` from `from` to `to`, keeping the same attribute name.
// Delegates to the (old_key, new_key) overload below.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}
  221. void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
  222. const AnfNodePtr &to) {
  223. MS_EXCEPTION_IF_NULL(from);
  224. MS_EXCEPTION_IF_NULL(to);
  225. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  226. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
  227. << to->DebugString();
  228. }
  229. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  230. MS_EXCEPTION_IF_NULL(from_primitive);
  231. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  232. MS_EXCEPTION_IF_NULL(to_primitive);
  233. to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
  234. }
  235. void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  236. MS_EXCEPTION_IF_NULL(from);
  237. MS_EXCEPTION_IF_NULL(to);
  238. if (!from->isa<CNode>() || !to->isa<CNode>()) {
  239. MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
  240. << from->DebugString();
  241. }
  242. auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  243. MS_EXCEPTION_IF_NULL(from_primitive);
  244. auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  245. MS_EXCEPTION_IF_NULL(to_primitive);
  246. (void)to_primitive->SetAttrs(from_primitive->attrs());
  247. }
  248. void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  249. MS_EXCEPTION_IF_NULL(node);
  250. if (!node->isa<CNode>()) {
  251. MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  252. }
  253. // single op cnode.
  254. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  255. if (primitive != nullptr) {
  256. primitive->EraseAttr(key);
  257. return;
  258. }
  259. // graph kernel cnode.
  260. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
  261. MS_EXCEPTION_IF_NULL(fg);
  262. fg->erase_flag(key);
  263. }
  264. bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
  265. MS_EXCEPTION_IF_NULL(node);
  266. if (!node->isa<CNode>()) {
  267. MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString();
  268. return false;
  269. }
  270. // single op cnode.
  271. auto primitive = AnfAlgo::GetCNodePrimitive(node);
  272. if (primitive != nullptr) {
  273. return primitive->HasAttr(key);
  274. }
  275. // graph kernel cnode.
  276. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
  277. MS_EXCEPTION_IF_NULL(fg);
  278. return fg->has_flag(key);
  279. }
  280. size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  281. MS_EXCEPTION_IF_NULL(node);
  282. if (!node->isa<CNode>()) {
  283. MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  284. }
  285. auto cnode = node->cast<CNodePtr>();
  286. MS_EXCEPTION_IF_NULL(cnode);
  287. size_t input_num = cnode->inputs().size();
  288. if (input_num == 0) {
  289. MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  290. }
  291. // exclude intputs[0],which is value_node storing attr,inputs left are real input
  292. return input_num - 1;
  293. }
  294. size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  295. MS_EXCEPTION_IF_NULL(node);
  296. TypePtr type = node->Type();
  297. if (type == nullptr) {
  298. return 0;
  299. }
  300. if (type->isa<Tuple>()) {
  301. auto tuple_type = type->cast<TuplePtr>();
  302. MS_EXCEPTION_IF_NULL(tuple_type);
  303. return tuple_type->size();
  304. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  305. return 1;
  306. } else if (type->isa<TypeNone>()) {
  307. return 0;
  308. } else {
  309. return 1;
  310. }
  311. }
// Device format (e.g. NCHW / NC1HWC0) of output `output_idx` of `node`,
// taken from the selected kernel build info. For virtual kernels the format
// is forwarded from the real producing node.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  // NOTE(review): bound uses '>' rather than '>=', so output_idx equal to the
  // output count is accepted — looks like an off-by-one; confirm callers
  // before tightening.
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  // Virtual kernels (TupleGetItem etc.) have no build info of their own.
  if (!AnfAlgo::IsRealKernel(node)) {
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOutputFormat(output_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format";
  }
  return format;
}
  333. std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  334. MS_EXCEPTION_IF_NULL(node);
  335. if (input_idx > GetInputTensorNum(node)) {
  336. MS_LOG(EXCEPTION) << "Input index :" << input_idx
  337. << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
  338. << node->DebugString() << "]";
  339. }
  340. if (!IsRealKernel(node)) {
  341. GetPrevNodeOutputFormat(node, input_idx);
  342. }
  343. auto kernel_info = node->kernel_info();
  344. MS_EXCEPTION_IF_NULL(kernel_info);
  345. auto build_info = kernel_info->select_kernel_build_info();
  346. MS_EXCEPTION_IF_NULL(build_info);
  347. auto format = build_info->GetInputFormat(input_idx);
  348. if (format == kernel::KernelBuildInfo::kInvalidFormat) {
  349. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  350. << " has a invalid input format";
  351. }
  352. return format;
  353. }
  354. KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  355. MS_EXCEPTION_IF_NULL(anf_node);
  356. if (!anf_node->isa<CNode>()) {
  357. MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  358. }
  359. auto cnode = anf_node->cast<CNodePtr>();
  360. MS_EXCEPTION_IF_NULL(cnode);
  361. if (input_idx + 1 >= cnode->inputs().size()) {
  362. MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  363. }
  364. auto node = cnode->input(input_idx + 1);
  365. MS_EXCEPTION_IF_NULL(node);
  366. return VisitKernel(node, 0);
  367. }
  368. std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  369. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  370. return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  371. }
  372. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  373. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  374. return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
  375. }
// Inferred (framework-level) shape of output `output_idx` of `node`.
// Single-output nodes carry an abstract::Shape; multi-output nodes carry a
// TupleShape indexed by output_idx; NoShape yields an empty vector (scalar).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    // Plain shape: only valid for the first (sole) output.
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar tuple element: empty shape.
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    // Scalar node: empty shape.
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}
  404. std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  405. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  406. return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
  407. }
  408. std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  409. auto format = GetOutputFormat(node, output_idx);
  410. auto infer_shape = GetOutputInferShape(node, output_idx);
  411. if (infer_shape.empty()) {
  412. return infer_shape;
  413. }
  414. // if format is default_format or NC1KHKWHWC0,device shape = original shape
  415. if (trans::IsNeedPadding(format, infer_shape.size())) {
  416. infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  417. }
  418. return trans::TransShapeToDevice(infer_shape, format);
  419. }
  420. std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  421. auto format = GetInputFormat(node, input_idx);
  422. auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  423. if (infer_shape.empty()) {
  424. return infer_shape;
  425. }
  426. // if format is default_format or NC1KHKWHWC0,device shape = original shape
  427. if (trans::IsNeedPadding(format, infer_shape.size())) {
  428. infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  429. }
  430. return trans::TransShapeToDevice(infer_shape, format);
  431. }
  432. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  433. MS_EXCEPTION_IF_NULL(node);
  434. if (input_idx > GetInputTensorNum(node)) {
  435. MS_LOG(EXCEPTION) << "The index:" << input_idx
  436. << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
  437. << node->DebugString() << "]";
  438. }
  439. if (!IsRealKernel(node)) {
  440. return GetPrevNodeOutputReshapeType(node, input_idx);
  441. }
  442. auto kernel_info = node->kernel_info();
  443. MS_EXCEPTION_IF_NULL(kernel_info);
  444. auto build_info = kernel_info->select_kernel_build_info();
  445. MS_EXCEPTION_IF_NULL(build_info);
  446. if (build_info->IsInputDefaultPadding()) {
  447. return {};
  448. }
  449. return build_info->GetInputReshapeType(input_idx);
  450. }
  451. std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  452. MS_EXCEPTION_IF_NULL(node);
  453. if (output_idx > GetOutputTensorNum(node)) {
  454. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  455. << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  456. }
  457. if (!IsRealKernel(node)) {
  458. return GetPrevNodeOutputReshapeType(node, output_idx);
  459. }
  460. auto kernel_info = node->kernel_info();
  461. MS_EXCEPTION_IF_NULL(kernel_info);
  462. auto build_info = kernel_info->select_kernel_build_info();
  463. MS_EXCEPTION_IF_NULL(build_info);
  464. if (build_info->IsOutputDefaultPadding()) {
  465. return {};
  466. }
  467. return build_info->GetOutputReshapeType(output_idx);
  468. }
// Inferred (framework-level) dtype of output `output_idx` of `node`.
// TensorType yields its element type; Tuple is indexed by output_idx; plain
// Number (and any other type) yields its own type_id.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    // Single tensor output: report the element dtype, not "Tensor".
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      // Unexpected element kind: warn but still return its type_id.
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  // Fallback: every other type reports its own id.
  return type_ptr->type_id();
}
  504. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  505. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  506. return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
  507. }
  508. TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  509. MS_EXCEPTION_IF_NULL(node);
  510. if (output_idx > GetOutputTensorNum(node)) {
  511. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  512. << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  513. }
  514. if (!IsRealKernel(node)) {
  515. return GetPrevNodeOutputDeviceDataType(node, output_idx);
  516. }
  517. auto kernel_info = node->kernel_info();
  518. MS_EXCEPTION_IF_NULL(kernel_info);
  519. auto build_info = kernel_info->select_kernel_build_info();
  520. MS_EXCEPTION_IF_NULL(build_info);
  521. auto dtype = build_info->GetOutputDeviceType(output_idx);
  522. if (dtype == TypeId::kNumberTypeEnd) {
  523. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  524. << " has a invalid dtype";
  525. }
  526. return dtype;
  527. }
  528. TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  529. MS_EXCEPTION_IF_NULL(node);
  530. if (input_idx > GetInputTensorNum(node)) {
  531. MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
  532. << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  533. }
  534. if (!IsRealKernel(node)) {
  535. return GetPrevNodeOutputDeviceDataType(node, 0);
  536. }
  537. auto kernel_info = node->kernel_info();
  538. MS_EXCEPTION_IF_NULL(kernel_info);
  539. auto build_info = kernel_info->select_kernel_build_info();
  540. MS_EXCEPTION_IF_NULL(build_info);
  541. auto dtype = build_info->GetInputDeviceType(input_idx);
  542. if (dtype == TypeId::kNumberTypeEnd) {
  543. MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
  544. << " has a invalid dtype";
  545. }
  546. return dtype;
  547. }
  548. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  549. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  550. return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  551. }
// get output device addr of anf_node
// Returns a non-owning pointer to the device address of output `output_idx`.
// When `visit_nop_node` is set, a nop node is transparent: the address of the
// output feeding its single real input is returned instead.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                        bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node) && visit_nop_node) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A valid nop node is exactly {primitive, real input}.
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Mutable (shared-ptr) variant of GetOutputAddr; same nop-node forwarding.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
                                                           bool visit_nop_node) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node) && visit_nop_node) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // A valid nop node is exactly {primitive, real input}.
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
  595. // get output device addr of anf_node
  596. bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  597. MS_EXCEPTION_IF_NULL(node);
  598. if (output_idx > GetOutputTensorNum(node)) {
  599. MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
  600. << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  601. }
  602. auto kernel_info = node->kernel_info();
  603. MS_EXCEPTION_IF_NULL(kernel_info);
  604. return kernel_info->OutputAddrExist(output_idx);
  605. }
  606. const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  607. bool visit_nop_node) {
  608. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  609. return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
  610. }
  611. DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
  612. bool visit_nop_node) {
  613. KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  614. return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
  615. }
  616. // set output device addr of anf_node
  617. void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  618. MS_EXCEPTION_IF_NULL(node);
  619. auto kernel_info = node->kernel_info();
  620. MS_EXCEPTION_IF_NULL(kernel_info);
  621. if (!kernel_info->SetOutputAddr(addr, output_idx)) {
  622. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  623. }
  624. }
  625. // set workspace device addr of anf_node
  626. void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  627. MS_EXCEPTION_IF_NULL(node);
  628. auto kernel_info = node->kernel_info();
  629. MS_EXCEPTION_IF_NULL(kernel_info);
  630. if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
  631. MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  632. }
  633. }
  634. // get workspace device addr of anf_node
  635. DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  636. MS_EXCEPTION_IF_NULL(node);
  637. auto kernel_info = node->kernel_info();
  638. MS_EXCEPTION_IF_NULL(kernel_info);
  639. auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  640. if (addr == nullptr) {
  641. MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
  642. << "] workspace addr is not exist";
  643. }
  644. return addr;
  645. }
  646. // set infer shapes and types of anf node
  647. void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
  648. const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  649. MS_EXCEPTION_IF_NULL(node);
  650. if (types.size() != shapes.size()) {
  651. MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  652. }
  653. if (shapes.empty()) {
  654. MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  655. } else if (shapes.size() == 1) {
  656. // single output handle
  657. std::vector<int> shape_int;
  658. std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
  659. auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
  660. node->set_abstract(abstract);
  661. } else {
  662. // multiple output handle
  663. std::vector<AbstractBasePtr> abstract_list;
  664. for (size_t i = 0; i < types.size(); ++i) {
  665. std::vector<int> shape_int;
  666. std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
  667. abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
  668. }
  669. auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
  670. node->set_abstract(abstract_tuple);
  671. }
  672. }
  673. // copy an abstract of a node to another node
  674. void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  675. to_node->set_abstract(from_node->abstract());
  676. }
  677. kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  678. MS_EXCEPTION_IF_NULL(node);
  679. auto kernel_info = node->kernel_info();
  680. MS_EXCEPTION_IF_NULL(kernel_info);
  681. // select_kernel_build_info() has checked whether return pointer is null
  682. auto build_info = kernel_info->select_kernel_build_info();
  683. MS_EXCEPTION_IF_NULL(build_info);
  684. return build_info->op_pattern();
  685. }
  686. // get KernelBuildType of node, such as ATT,RT,FWK and so on
  687. KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  688. MS_EXCEPTION_IF_NULL(node);
  689. auto kernel_info = node->kernel_info();
  690. MS_EXCEPTION_IF_NULL(kernel_info);
  691. // select_kernel_build_info() has checked whether return pointer is null
  692. auto build_info = kernel_info->select_kernel_build_info();
  693. MS_EXCEPTION_IF_NULL(build_info);
  694. return build_info->kernel_type();
  695. }
  696. kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  697. MS_EXCEPTION_IF_NULL(node);
  698. auto kernel_info = node->kernel_info();
  699. MS_EXCEPTION_IF_NULL(kernel_info);
  700. auto build_info = kernel_info->select_kernel_build_info();
  701. MS_EXCEPTION_IF_NULL(build_info);
  702. return build_info->processor();
  703. }
  704. kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  705. MS_EXCEPTION_IF_NULL(node);
  706. auto kernel_info = node->kernel_info();
  707. MS_EXCEPTION_IF_NULL(kernel_info);
  708. auto build_info = kernel_info->select_kernel_build_info();
  709. MS_EXCEPTION_IF_NULL(build_info);
  710. return build_info->fusion_type();
  711. }
  712. // set select kernel_build_info
  713. void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  714. MS_EXCEPTION_IF_NULL(node);
  715. auto kernel_info = node->kernel_info();
  716. MS_EXCEPTION_IF_NULL(kernel_info);
  717. return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
  718. }
  719. // get select kernel_build_info
  720. KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  721. MS_EXCEPTION_IF_NULL(node);
  722. auto kernel_info = node->kernel_info();
  723. MS_EXCEPTION_IF_NULL(kernel_info);
  724. return kernel_info->GetMutableSelectKernelBuildInfo();
  725. }
  726. // get kernelMode
  727. KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  728. MS_EXCEPTION_IF_NULL(node);
  729. auto kernel_info = node->kernel_info();
  730. MS_EXCEPTION_IF_NULL(kernel_info);
  731. return kernel_info->MutableKernelMod();
  732. }
  733. // set kernel mod
  734. void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  735. MS_EXCEPTION_IF_NULL(node);
  736. auto kernel_info = node->kernel_info();
  737. MS_EXCEPTION_IF_NULL(kernel_info);
  738. kernel_info->set_kernel_mod(kernel_mod);
  739. }
  740. bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  741. MS_EXCEPTION_IF_NULL(node);
  742. // parameter and value node is not a real kernel too
  743. if (!node->isa<CNode>()) {
  744. return true;
  745. }
  746. auto cnode = node->cast<CNodePtr>();
  747. MS_EXCEPTION_IF_NULL(cnode);
  748. if (cnode->inputs().empty()) {
  749. MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  750. }
  751. auto input = cnode->inputs()[0];
  752. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  753. IsPrimitive(input, prim::kPrimTensorSummary) ||
  754. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  755. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  756. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  757. IsPrimitive(input, prim::kPrimReturn);
  758. return !is_virtual_node;
  759. }
  760. bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  761. MS_EXCEPTION_IF_NULL(node);
  762. // parameter and value node is not a real cnode kernel
  763. if (!node->isa<CNode>()) {
  764. return false;
  765. }
  766. // return considered as a real node
  767. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  768. return true;
  769. }
  770. return IsRealKernel(node);
  771. }
  772. bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
  773. MS_EXCEPTION_IF_NULL(node);
  774. // graph kernel should be a real cnode kernel.
  775. if (!IsRealCNodeKernel(node)) {
  776. return false;
  777. }
  778. auto cnode = node->cast<CNodePtr>();
  779. MS_EXCEPTION_IF_NULL(cnode);
  780. auto input = cnode->input(kAnfPrimitiveIndex);
  781. // graph kernel should has func_graph as first input.
  782. if (!IsValueNode<FuncGraph>(input)) {
  783. return false;
  784. }
  785. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  786. MS_EXCEPTION_IF_NULL(func_graph);
  787. return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  788. }
  789. bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  790. MS_EXCEPTION_IF_NULL(node);
  791. return node->has_default();
  792. }
  793. void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  794. MS_EXCEPTION_IF_NULL(node);
  795. auto kernel_info = node->kernel_info();
  796. MS_EXCEPTION_IF_NULL(kernel_info);
  797. kernel_info->set_stream_id(stream_id);
  798. }
  799. uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  800. MS_EXCEPTION_IF_NULL(node);
  801. auto kernel_info = node->kernel_info();
  802. MS_EXCEPTION_IF_NULL(kernel_info);
  803. return kernel_info->stream_id();
  804. }
  805. void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  806. MS_EXCEPTION_IF_NULL(node);
  807. auto kernel_info = node->kernel_info();
  808. MS_EXCEPTION_IF_NULL(kernel_info);
  809. kernel_info->set_stream_distinction_label(stream_label);
  810. }
  811. uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  812. MS_EXCEPTION_IF_NULL(node);
  813. auto kernel_info = node->kernel_info();
  814. MS_EXCEPTION_IF_NULL(kernel_info);
  815. return kernel_info->stream_distinction_label();
  816. }
  817. void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  818. MS_EXCEPTION_IF_NULL(node);
  819. auto kernel_info = node->kernel_info();
  820. MS_EXCEPTION_IF_NULL(kernel_info);
  821. kernel_info->set_graph_id(graph_id);
  822. }
  823. uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  824. MS_EXCEPTION_IF_NULL(node);
  825. auto kernel_info = node->kernel_info();
  826. MS_EXCEPTION_IF_NULL(kernel_info);
  827. return kernel_info->graph_id();
  828. }
  829. bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  830. MS_EXCEPTION_IF_NULL(anf);
  831. TypePtr type = anf->Type();
  832. MS_EXCEPTION_IF_NULL(type);
  833. return type->isa<Tuple>();
  834. }
  835. AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  836. MS_EXCEPTION_IF_NULL(node);
  837. auto get_input_index = index + 1;
  838. if (index + 1 > node->inputs().size()) {
  839. MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
  840. << node->inputs().size();
  841. }
  842. // input 0 is primitive node
  843. return node->input(get_input_index);
  844. }
  845. bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  846. MS_EXCEPTION_IF_NULL(node);
  847. if (node->isa<ValueNode>()) {
  848. return false;
  849. }
  850. auto kernel_info = node->kernel_info();
  851. MS_EXCEPTION_IF_NULL(kernel_info);
  852. return kernel_info->is_feature_map();
  853. }
  854. bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  855. if (!node->isa<CNode>()) {
  856. MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  857. }
  858. auto cnode = node->cast<CNodePtr>();
  859. MS_EXCEPTION_IF_NULL(cnode);
  860. auto input_node = cnode->input(input_index + 1);
  861. return IsFeatureMapOutput(input_node);
  862. }
  863. size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  864. MS_EXCEPTION_IF_NULL(anf_node);
  865. static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
  866. {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
  867. {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
  868. {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  869. {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
  870. {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
  871. {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  872. {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
  873. {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
  874. {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  875. {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
  876. {prim::kPrimApplyCenteredRMSProp->name(),
  877. {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
  878. size_t ret = cur_index;
  879. auto node_name = AnfAlgo::GetCNodeName(anf_node);
  880. if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
  881. auto find = spec_node_list.find(node_name);
  882. if (find != spec_node_list.end()) {
  883. ret = find->second[cur_index];
  884. MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
  885. }
  886. }
  887. return ret;
  888. }
  889. void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  890. MS_EXCEPTION_IF_NULL(node);
  891. MS_EXCEPTION_IF_NULL(input_node);
  892. node->set_input(index + 1, input_node);
  893. }
  894. bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  895. MS_EXCEPTION_IF_NULL(node);
  896. if (!node->isa<CNode>()) {
  897. return false;
  898. }
  899. auto kernel_name = AnfAlgo::GetCNodeName(node);
  900. if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
  901. kernel_name == kReduceScatterOpName) {
  902. return true;
  903. }
  904. return false;
  905. }
  906. bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  907. auto kernel_name = AnfAlgo::GetCNodeName(node);
  908. return kernel_name == kGetNextOpName;
  909. }
  910. FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  911. MS_EXCEPTION_IF_NULL(node);
  912. auto value_node = node->cast<ValueNodePtr>();
  913. if (value_node == nullptr) {
  914. return nullptr;
  915. }
  916. auto value = value_node->value();
  917. if (value == nullptr) {
  918. return nullptr;
  919. }
  920. auto func_graph = value->cast<FuncGraphPtr>();
  921. return func_graph;
  922. }
  923. std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
  924. if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
  925. MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
  926. }
  927. MS_EXCEPTION_IF_NULL(call_node);
  928. auto input1 = call_node->input(1);
  929. MS_EXCEPTION_IF_NULL(input1);
  930. if (input1->isa<ValueNode>()) {
  931. auto value_node = input1->cast<ValueNodePtr>();
  932. MS_EXCEPTION_IF_NULL(value_node);
  933. auto kernel_graph = value_node->value();
  934. MS_EXCEPTION_IF_NULL(kernel_graph);
  935. return {kernel_graph->cast<KernelGraphPtr>()};
  936. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  937. auto switch_node = input1->cast<CNodePtr>();
  938. MS_EXCEPTION_IF_NULL(switch_node);
  939. auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
  940. auto partial = switch_node->input(input_index);
  941. MS_EXCEPTION_IF_NULL(partial);
  942. auto partial_cnode = partial->cast<CNodePtr>();
  943. MS_EXCEPTION_IF_NULL(partial_cnode);
  944. auto graph_node = partial_cnode->input(1);
  945. MS_EXCEPTION_IF_NULL(graph_node);
  946. auto graph_value_node = graph_node->cast<ValueNodePtr>();
  947. MS_EXCEPTION_IF_NULL(graph_value_node);
  948. auto graph_value = graph_value_node->value();
  949. MS_EXCEPTION_IF_NULL(graph_value);
  950. auto child_graph = graph_value->cast<KernelGraphPtr>();
  951. return child_graph;
  952. };
  953. return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
  954. }
  955. return {};
  956. }
  957. bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  958. MS_EXCEPTION_IF_NULL(call_node);
  959. if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
  960. MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
  961. }
  962. auto input1 = call_node->input(1);
  963. if (input1->isa<ValueNode>()) {
  964. return false;
  965. } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
  966. return true;
  967. }
  968. MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
  969. }
  970. bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
  971. auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  972. if (shape.empty()) {
  973. return true;
  974. }
  975. return shape.size() == kShape1dDims && shape[0] == 1;
  976. }
  977. bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
  978. auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  979. if (shape.empty()) {
  980. return true;
  981. }
  982. return shape.size() == kShape1dDims && shape[0] == 1;
  983. }
  984. void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
  985. std::vector<CNodePtr> all_opt_list;
  986. std::vector<CNodePtr> non_opt_list;
  987. for (const auto &node : *node_list) {
  988. MS_EXCEPTION_IF_NULL(node);
  989. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
  990. all_opt_list.emplace_back(node);
  991. } else {
  992. non_opt_list.emplace_back(node);
  993. }
  994. }
  995. node_list->clear();
  996. std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
  997. std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
  998. }
  999. TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
  1000. MS_EXCEPTION_IF_NULL(node);
  1001. auto prim = AnfAlgo::GetCNodePrimitive(node);
  1002. if (prim == nullptr) {
  1003. return kTypeUnknown;
  1004. }
  1005. TypeId except_type = kTypeUnknown;
  1006. if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
  1007. auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
  1008. if (output_type_str == "float16") {
  1009. except_type = kNumberTypeFloat16;
  1010. } else if (output_type_str == "float32") {
  1011. except_type = kNumberTypeFloat32;
  1012. } else {
  1013. MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str;
  1014. }
  1015. }
  1016. return except_type;
  1017. }
  1018. TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
  1019. if (!node->isa<CNode>()) {
  1020. MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not CNode.";
  1021. }
  1022. auto cnode = node->cast<CNodePtr>();
  1023. MS_EXCEPTION_IF_NULL(cnode);
  1024. if (input_idx + 1 >= cnode->inputs().size()) {
  1025. MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  1026. }
  1027. auto input_node = cnode->input(input_idx + 1);
  1028. MS_EXCEPTION_IF_NULL(input_node);
  1029. auto kernel_with_index = VisitKernel(input_node, 0);
  1030. if (!kernel_with_index.first->isa<CNode>()) {
  1031. return kTypeUnknown;
  1032. }
  1033. return GetCNodeOutputPrecision(kernel_with_index.first);
  1034. }
  1035. } // namespace session
  1036. } // namespace mindspore