From 04c512ea411c2235da3fa835955d92741d7db489 Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Mon, 1 Mar 2021 09:30:15 +0800 Subject: [PATCH] add perm attr to transpose when transdata spilt --- mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc | 5 ++++- mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h | 3 ++- .../format_type/insert_transpose_for_dynamic_gru_v2.cc | 6 +++--- .../optimizer/ascend/ir_fission/transdata_split.cc | 8 +++----- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 7bc76b4569..911d58130a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -240,7 +240,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & } CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name) { + const bool need_padding, const std::string &op_name, const std::vector &perm) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(input); CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared(op_name)), input}); @@ -261,6 +261,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, auto kernel_info = std::make_shared(); trans_node->set_kernel_info(kernel_info); } + if (op_name == prim::kPrimTranspose->name()) { + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node); + } MS_EXCEPTION_IF_NULL(kernel_select); kernel_select->SelectKernel(trans_node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index 5c2d497071..9afb0da47e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -92,7 +92,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & const TypeId &type_id = kTypeUnknown); CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name); + const bool need_padding, const std::string &op_name, + const std::vector &perm = std::vector{}); CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc index 63e31f50fc..0fa3d0acee 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc @@ -53,9 +53,9 @@ CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { auto padding_axis = AnfAlgo::GetOutputReshapeType(transdata_node, 0); KernelSelectPtr kernel_select = std::make_shared(); // trans default to hwcn - new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast(), 0), - kernel_select, false, prim::kPrimTranspose->name()); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); + new_transpose_node = + NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast(), 0), kernel_select, false, + prim::kPrimTranspose->name(), std::vector{2, 3, 1, 0}); AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node); RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); // trans hwcn to output_format diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc index 10abf6d093..fd1d95aa4c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc @@ -78,16 +78,14 @@ CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePt false, prim::KPrimTransData->name()); RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis); // trans hwcn to default_format - new_transpose_node = - NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); + new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, + prim::kPrimTranspose->name(), std::vector{3, 2, 0, 1}); RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); new_replace_node = new_transpose_node; } else { // trans default to hwcn new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, - false, prim::kPrimTranspose->name()); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); + false, prim::kPrimTranspose->name(), std::vector{2, 3, 1, 0}); if (output_format == kOpFormat_FRACTAL_ZN_LSTM) { AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node); }