You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ascend_helper.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pre_activate/ascend/ascend_helper.h"
  17. #include <set>
  18. #include "common/trans.h"
  19. #include "common/utils.h"
  20. #include "pre_activate/common/helper.h"
  21. #include "utils/utils.h"
  22. #include "device/kernel_info.h"
  23. #include "kernel/oplib/oplib.h"
  24. #include "operator/ops.h"
  25. #include "session/anf_runtime_algorithm.h"
  26. #include "session/kernel_graph.h"
  27. #include "utils/context/ms_context.h"
  28. namespace mindspore {
  29. namespace opt {
  30. using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
  31. namespace {
  32. kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
  33. const AnfNodePtr &node,
  34. const kernel::KernelBuildInfo ori_build_info) {
  35. KernelBuildInfoBuilder builder;
  36. builder.SetInputsFormat({input_format});
  37. builder.SetOutputsFormat({output_format});
  38. builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)});
  39. builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)});
  40. builder.SetKernelType(ori_build_info.kernel_type());
  41. builder.SetFusionType(ori_build_info.fusion_type());
  42. builder.SetProcessor(ori_build_info.processor());
  43. return builder.Build();
  44. }
  45. CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
  46. const bool need_padding, const std::string &op_name) {
  47. MS_EXCEPTION_IF_NULL(func_graph);
  48. MS_EXCEPTION_IF_NULL(input);
  49. std::vector<AnfNodePtr> trans_inputs;
  50. auto prim = std::make_shared<Primitive>(op_name);
  51. trans_inputs.push_back(NewValueNode(prim));
  52. trans_inputs.push_back(input);
  53. CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
  54. MS_EXCEPTION_IF_NULL(trans_node);
  55. if (need_padding) {
  56. // if need padding we should set the transdata node's shape to the padding shape
  57. AnfAlgo::SetOutputInferTypeAndShape(
  58. {AnfAlgo::GetOutputInferDataType(input, 0)},
  59. {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
  60. trans_node.get());
  61. } else {
  62. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
  63. {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
  64. }
  65. // special handle for ut
  66. if (trans_node->kernel_info() == nullptr) {
  67. auto kernel_info = std::make_shared<device::KernelInfo>();
  68. trans_node->set_kernel_info(kernel_info);
  69. }
  70. MS_EXCEPTION_IF_NULL(kernel_select);
  71. kernel_select->SelectKernel(trans_node);
  72. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  73. MS_EXCEPTION_IF_NULL(trans_node);
  74. trans_node->set_scope(input->scope());
  75. return trans_node;
  76. }
  77. AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
  78. const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
  79. std::vector<AnfNodePtr> trans_inputs;
  80. auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  81. trans_inputs.emplace_back(NewValueNode(prim));
  82. trans_inputs.emplace_back(input_node);
  83. auto reshape = func_graph->NewCNode(trans_inputs);
  84. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
  85. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
  86. AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
  87. reshape->set_scope(input_node->scope());
  88. kernel_select->SelectKernel(reshape);
  89. return reshape;
  90. }
  91. AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
  92. const KernelSelectPtr &kernel_select) {
  93. MS_EXCEPTION_IF_NULL(node);
  94. auto input_node = AnfAlgo::GetInputNode(node, index);
  95. auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
  96. MS_EXCEPTION_IF_NULL(node_with_index.first);
  97. auto real_input = node_with_index.first;
  98. if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
  99. input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
  100. MS_EXCEPTION_IF_NULL(input_node);
  101. AnfAlgo::SetNodeInput(node, input_node, index);
  102. }
  103. if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
  104. MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
  105. << "when inserting the transdata node " << node->DebugString();
  106. }
  107. std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
  108. std::string origin_format = kOpFormat_DEFAULT;
  109. std::string dest_format = AnfAlgo::GetInputFormat(node, index);
  110. if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
  111. MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
  112. << " To DefaultFormat , index: " << index;
  113. return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
  114. true);
  115. }
  116. return input_node;
  117. }
  118. AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  119. const KernelSelectPtr &kernel_select) {
  120. MS_EXCEPTION_IF_NULL(node);
  121. std::string output_format;
  122. std::vector<size_t> origin_shape;
  123. if (!AnfAlgo::IsRealKernel(node)) {
  124. output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
  125. origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  126. } else {
  127. output_format = AnfAlgo::GetOutputFormat(node, 0);
  128. origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
  129. }
  130. if (output_format == kOpFormat_NC1KHKWHWC0) {
  131. MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
  132. << node->DebugString();
  133. }
  134. std::string origin_format = output_format;
  135. std::string dest_format = kOpFormat_DEFAULT;
  136. if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
  137. MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
  138. return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
  139. false);
  140. }
  141. return node;
  142. }
  143. AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  144. const KernelSelectPtr &kernel_select) {
  145. MS_EXCEPTION_IF_NULL(func_graph);
  146. MS_EXCEPTION_IF_NULL(node);
  147. std::vector<AnfNodePtr> make_tuple_inputs;
  148. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  149. for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
  150. std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
  151. if (output_format == kOpFormat_NC1KHKWHWC0) {
  152. MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
  153. << node->DebugString();
  154. }
  155. auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
  156. std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  157. std::string dest_format = kOpFormat_DEFAULT;
  158. if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
  159. make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
  160. dest_format, kTransDataOpName, false));
  161. } else {
  162. // No need insert trans op.
  163. make_tuple_inputs.push_back(tuple_getitem);
  164. }
  165. }
  166. AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  167. return make_tuple;
  168. }
  169. } // namespace
  170. AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  171. const KernelSelectPtr &kernel_select, size_t insert_index,
  172. const std::string &origin_format, const std::string &dest_format,
  173. const std::string &op_name, bool is_insert_input) {
  174. AnfNodePtr trans_node = nullptr;
  175. AnfNodePtr input_node = node;
  176. AnfNodePtr trans_data = nullptr;
  177. MS_EXCEPTION_IF_NULL(node);
  178. if (origin_format.empty() || dest_format.empty()) {
  179. MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
  180. }
  181. // if insert transdata for input we need to change the input
  182. if (is_insert_input) {
  183. if (!node->isa<CNode>()) {
  184. MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
  185. }
  186. auto cnode = node->cast<CNodePtr>();
  187. MS_EXCEPTION_IF_NULL(cnode);
  188. input_node = AnfAlgo::GetInputNode(cnode, insert_index);
  189. }
  190. bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
  191. op_name == kTransDataOpName);
  192. if (!need_padding) {
  193. // don't need padding insert transdata only
  194. trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
  195. trans_node = trans_data;
  196. } else if (is_insert_input) {
  197. // if need padding & is input need insert a transdata
  198. // reshape[padding shape] -> transdata[padding shape] -> node
  199. auto padding_shape =
  200. trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
  201. auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
  202. trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
  203. trans_node = trans_data;
  204. } else {
  205. // if need padding & is output need insert a transdata
  206. // node -> transdata[padding shape] -> reshape[ori_shape]
  207. trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
  208. auto reshape_node =
  209. CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
  210. trans_node = reshape_node;
  211. }
  212. // refresh the transdata's format to ori format & dst format
  213. MS_EXCEPTION_IF_NULL(trans_data);
  214. MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
  215. auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
  216. auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
  217. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
  218. return trans_node;
  219. }
  220. AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
  221. const TypeId &input_type, const TypeId &output_type,
  222. const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
  223. MS_EXCEPTION_IF_NULL(func_graph);
  224. std::string input_format = format;
  225. std::string output_format = format;
  226. std::vector<AnfNodePtr> new_cast_inputs;
  227. auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
  228. new_cast_inputs.push_back(NewValueNode(prim));
  229. new_cast_inputs.push_back(input);
  230. CNodePtr cast = func_graph->NewCNode(new_cast_inputs);
  231. MS_EXCEPTION_IF_NULL(cast);
  232. // set kernel build info
  233. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  234. builder.SetInputsFormat({input_format});
  235. builder.SetOutputsFormat({output_format});
  236. builder.SetInputsDeviceType({input_type});
  237. builder.SetOutputsDeviceType({output_type});
  238. builder.SetFusionType(kernel::FusionType::OPAQUE);
  239. builder.SetProcessor(kernel::Processor::AICORE);
  240. if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
  241. builder.SetKernelType(KernelType::TBE_KERNEL);
  242. } else {
  243. builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
  244. }
  245. // if kernel info is null , it remarks this function is running ut
  246. if (cast->kernel_info() == nullptr) {
  247. auto kernel_info = std::make_shared<device::KernelInfo>();
  248. cast->set_kernel_info(kernel_info);
  249. }
  250. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  251. AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
  252. return cast;
  253. }
  254. AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  255. const KernelSelectPtr &kernel_select) {
  256. size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
  257. if (outputs_num == 0) {
  258. return node;
  259. }
  260. // Single output
  261. if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
  262. return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
  263. }
  264. // Multiple output
  265. return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
  266. }
  267. AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  268. const KernelSelectPtr &kernel_select) {
  269. MS_EXCEPTION_IF_NULL(node);
  270. auto cnode = node->cast<CNodePtr>();
  271. MS_EXCEPTION_IF_NULL(cnode);
  272. std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  273. for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
  274. AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
  275. MS_EXCEPTION_IF_NULL(input_node);
  276. new_inputs.push_back(input_node);
  277. }
  278. CNodePtr new_cnode = nullptr;
  279. // cnode changed so make a new cnode to differ from original one.
  280. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  281. if (kernel_graph == nullptr) {
  282. new_cnode = std::make_shared<CNode>(*cnode);
  283. } else {
  284. new_cnode = kernel_graph->NewCNode(cnode);
  285. }
  286. MS_EXCEPTION_IF_NULL(new_cnode);
  287. new_cnode->set_inputs(new_inputs);
  288. return new_cnode;
  289. }
  290. CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  291. MS_EXCEPTION_IF_NULL(cnode);
  292. std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  293. for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
  294. TypeId origin_type;
  295. auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
  296. auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
  297. auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
  298. if (node->isa<ValueNode>()) {
  299. return true;
  300. } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  301. return true;
  302. }
  303. return false;
  304. };
  305. auto real_input_node = kernel_with_index.first;
  306. if (is_weight_boundary(real_input_node)) {
  307. // weight
  308. origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
  309. } else {
  310. // feature map
  311. origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
  312. }
  313. const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
  314. const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
  315. const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
  316. if (origin_type != device_type) {
  317. auto cast =
  318. AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, origin_type);
  319. MS_EXCEPTION_IF_NULL(cast);
  320. cast->set_scope(cnode->scope());
  321. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
  322. new_inputs.push_back(cast);
  323. } else {
  324. new_inputs.push_back(cur_input);
  325. }
  326. }
  327. auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  328. CNodePtr new_node = nullptr;
  329. if (kernel_graph == nullptr) {
  330. new_node = std::make_shared<CNode>(*cnode);
  331. } else {
  332. new_node = kernel_graph->NewCNode(cnode);
  333. }
  334. MS_EXCEPTION_IF_NULL(new_node);
  335. new_node->set_inputs(new_inputs);
  336. return new_node;
  337. }
  338. AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  339. MS_EXCEPTION_IF_NULL(graph);
  340. MS_EXCEPTION_IF_NULL(node);
  341. auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  342. std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  343. auto new_node = graph->NewCNode(new_node_inputs);
  344. MS_EXCEPTION_IF_NULL(new_node);
  345. new_node->set_abstract(node->abstract());
  346. new_node->set_scope(node->scope());
  347. return new_node;
  348. }
  349. } // namespace opt
  350. } // namespace mindspore