You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ascend_helper.cc 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/ascend/ascend_helper.h"
  17. #include <set>
  18. #include "common/trans.h"
  19. #include "utils/ms_utils.h"
  20. #include "backend/optimizer/common/helper.h"
  21. #include "utils/utils.h"
  22. #include "runtime/device/kernel_info.h"
  23. #include "backend/kernel_compiler/oplib/oplib.h"
  24. #include "backend/kernel_compiler/common_utils.h"
  25. #include "base/core_ops.h"
  26. #include "backend/session/anf_runtime_algorithm.h"
  27. #include "backend/session/kernel_graph.h"
  28. #include "utils/ms_context.h"
  29. namespace mindspore {
  30. namespace opt {
  31. using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
  32. namespace {
// Formats treated as layout-compatible with the default format; no TransData
// node is inserted when an edge already uses one of these.
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
  34. AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
  35. const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  36. std::vector<AnfNodePtr> trans_inputs;
  37. auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  38. trans_inputs.emplace_back(NewValueNode(prim));
  39. trans_inputs.emplace_back(input_node);
  40. auto reshape = func_graph->NewCNode(trans_inputs);
  41. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  42. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  43. AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  44. reshape->set_scope(input_node->scope());
  45. kernel_select->SelectKernel(reshape);
  46. return reshape;
  47. }
  48. void SetTransNodeAttr(const CNodePtr &trans_node) {
  49. MS_EXCEPTION_IF_NULL(trans_node);
  50. if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
  51. std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
  52. std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
  53. if (input_format == kOpFormat_DEFAULT) {
  54. input_format = kOpFormat_NCHW;
  55. }
  56. if (output_format == kOpFormat_DEFAULT) {
  57. output_format = kOpFormat_NCHW;
  58. }
  59. AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
  60. AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
  61. }
  62. }
// Inserts a TransData (plus a Reshape when padding is required) either on
// input edge `insert_index` of `node` (is_insert_input=true: default format ->
// the node's expected input format) or on output `insert_index`
// (is_insert_input=false: the node's output format -> default format).
// Returns the node callers should wire in place of the original edge.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  CNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  // Init
  // For the input case, work on the producer feeding this input; for the
  // output case, work on `node` itself.
  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
  // Padding is decided against the non-default side of the conversion.
  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
    trans_data->set_abstract(input_node->abstract());
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  return trans_node;
}
  102. AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
  103. const KernelSelectPtr &kernel_select) {
  104. MS_EXCEPTION_IF_NULL(node);
  105. auto input_node = AnfAlgo::GetInputNode(node, index);
  106. auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  107. MS_EXCEPTION_IF_NULL(node_with_index.first);
  108. auto real_input = node_with_index.first;
  109. if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
  110. input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
  111. MS_EXCEPTION_IF_NULL(input_node);
  112. AnfAlgo::SetNodeInput(node, input_node, index);
  113. }
  114. std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  115. std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  116. if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
  117. MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
  118. << " To DefaultFormat , index: " << index;
  119. return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
  120. }
  121. return input_node;
  122. }
  123. AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  124. const KernelSelectPtr &kernel_select) {
  125. MS_EXCEPTION_IF_NULL(node);
  126. std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
  127. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  128. if (output_format == kOpFormat_NC1KHKWHWC0) {
  129. MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
  130. << node->DebugString();
  131. }
  132. if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
  133. MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
  134. return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
  135. }
  136. return node;
  137. }
  138. void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) {
  139. MS_EXCEPTION_IF_NULL(node);
  140. if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) {
  141. auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  142. auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  143. AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
  144. }
  145. }
// Wraps each output of a multi-output `node` in a TupleGetItem and, where the
// output's device format is not a common format (and the shape is >1-D), a
// trans op; all results are regrouped under a MakeTuple which is returned.
// Internal graph outputs are re-registered onto the inserted trans op.
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const KernelSelectPtr &kernel_select) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  size_t out_num = AnfAlgo::GetOutputTensorNum(node);
  std::string op_name;
  if (node->isa<CNode>()) {
    op_name = AnfAlgo::GetCNodeName(node);
  }
  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
                        << node->DebugString();
    }
    auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
    if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
      // The trans op consumes output 0 of tuple_getitem, hence insert_index 0.
      auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
      // Re-derive the infer shape for the BasicLSTMCellWeightGrad Reshape case.
      ReFreshInferShape(trans_op, op_name);
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
        kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
      }
      make_tuple_inputs.push_back(trans_op);
    } else {
      // No need insert trans op.
      make_tuple_inputs.push_back(tuple_getitem);
    }
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
  180. } // namespace
  181. void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
  182. const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
  183. const TypeId &type_id) {
  184. MS_EXCEPTION_IF_NULL(trans_data);
  185. auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
  186. MS_EXCEPTION_IF_NULL(ori_build_info);
  187. auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
  188. builder->SetInputsFormat({input_format});
  189. builder->SetInputsReshapeType({reshape_type});
  190. builder->SetOutputsReshapeType({reshape_type});
  191. builder->SetOutputsFormat({output_format});
  192. if (type_id != kTypeUnknown) {
  193. builder->SetOutputsDeviceType({type_id});
  194. builder->SetInputsDeviceType({type_id});
  195. }
  196. AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
  197. SetTransNodeAttr(trans_data->cast<CNodePtr>());
  198. }
  199. CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
  200. const bool need_padding, const std::string &op_name) {
  201. MS_EXCEPTION_IF_NULL(func_graph);
  202. MS_EXCEPTION_IF_NULL(input);
  203. CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
  204. MS_EXCEPTION_IF_NULL(trans_node);
  205. if (need_padding) {
  206. // if need padding we should set the transdata node's shape to the padding shape
  207. auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
  208. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
  209. {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
  210. trans_node.get());
  211. } else {
  212. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
  213. {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
  214. }
  215. // special handle for ut
  216. if (trans_node->kernel_info() == nullptr) {
  217. auto kernel_info = std::make_shared<device::KernelInfo>();
  218. trans_node->set_kernel_info(kernel_info);
  219. }
  220. MS_EXCEPTION_IF_NULL(kernel_select);
  221. kernel_select->SelectKernel(trans_node);
  222. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  223. AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
  224. MS_EXCEPTION_IF_NULL(trans_node);
  225. trans_node->set_scope(input->scope());
  226. return trans_node;
  227. }
  228. AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
  229. const TypeId &input_type, const TypeId &output_type,
  230. const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
  231. MS_EXCEPTION_IF_NULL(func_graph);
  232. std::string input_format = format;
  233. std::string output_format = format;
  234. CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
  235. MS_EXCEPTION_IF_NULL(cast);
  236. // set kernel build info
  237. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  238. builder.SetInputsFormat({input_format});
  239. builder.SetOutputsFormat({output_format});
  240. builder.SetInputsDeviceType({input_type});
  241. builder.SetOutputsDeviceType({output_type});
  242. builder.SetFusionType(kernel::FusionType::OPAQUE);
  243. builder.SetProcessor(kernel::Processor::AICORE);
  244. if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
  245. builder.SetKernelType(KernelType::TBE_KERNEL);
  246. } else {
  247. builder.SetKernelType(KernelType::AKG_KERNEL);
  248. }
  249. // if kernel info is null , it remarks this function is running ut
  250. if (cast->kernel_info() == nullptr) {
  251. auto kernel_info = std::make_shared<device::KernelInfo>();
  252. cast->set_kernel_info(kernel_info);
  253. }
  254. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  255. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
  256. AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
  257. AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
  258. return cast;
  259. }
  260. AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  261. const KernelSelectPtr &kernel_select) {
  262. size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
  263. if (outputs_num == 0) {
  264. return node;
  265. }
  266. auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  267. // Single output
  268. if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
  269. auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
  270. if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) {
  271. kernel_graph->ReplaceInternalOutput(node, new_node);
  272. }
  273. return new_node;
  274. }
  275. // Multiple output
  276. return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
  277. }
  278. AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  279. const KernelSelectPtr &kernel_select) {
  280. MS_EXCEPTION_IF_NULL(node);
  281. auto cnode = node->cast<CNodePtr>();
  282. MS_EXCEPTION_IF_NULL(cnode);
  283. std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  284. size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
  285. for (size_t input_index = 0; input_index < in_num; ++input_index) {
  286. AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
  287. MS_EXCEPTION_IF_NULL(input_node);
  288. new_inputs.push_back(input_node);
  289. }
  290. CNodePtr new_cnode = nullptr;
  291. // cnode changed so make a new cnode to differ from original one.
  292. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  293. if (kernel_graph == nullptr) {
  294. new_cnode = std::make_shared<CNode>(*cnode);
  295. } else {
  296. new_cnode = kernel_graph->NewCNode(cnode);
  297. }
  298. MS_EXCEPTION_IF_NULL(new_cnode);
  299. new_cnode->set_inputs(new_inputs);
  300. return new_cnode;
  301. }
// Rebuilds `cnode` with Cast ops inserted on every input whose origin data
// type differs from the device data type the selected kernel expects.
// Returns the new CNode; the original node is left untouched.
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    TypeId origin_type(kTypeUnknown);
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
    auto real_input_node = kernel_with_index.first;
    // Pick the "origin" type differently for weights vs feature maps: weights
    // use the recorded precision (falling back to the device type), feature
    // maps use the inferred type.
    if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
      // weight
      origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
      if (origin_type == kTypeUnknown) {
        origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second);
      }
    } else {
      // feature map
      origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    }
    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
    // In graph kernel, we check parameter,
    // the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
    if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
      new_inputs.push_back(cur_input);
    } else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
      // Types disagree: insert a Cast from origin_type to device_type while
      // keeping the producer's inferred type/shape on the Cast node.
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      new_inputs.push_back(cast);
    } else {
      new_inputs.push_back(cur_input);
    }
  }
  // Build the replacement CNode (a kernel graph, when present, owns the node).
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
  351. AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  352. MS_EXCEPTION_IF_NULL(graph);
  353. MS_EXCEPTION_IF_NULL(node);
  354. auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  355. std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  356. auto new_node = graph->NewCNode(new_node_inputs);
  357. MS_EXCEPTION_IF_NULL(new_node);
  358. new_node->set_abstract(node->abstract());
  359. new_node->set_scope(node->scope());
  360. AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
  361. return new_node;
  362. }
  363. } // namespace opt
  364. } // namespace mindspore