Browse Source

!12725 add perm attr to transpose when transdata spilt

From: @lianliguang
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
16652972cd
4 changed files with 12 additions and 10 deletions
  1. +4
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  2. +2
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h
  3. +3
    -3
      mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc
  4. +3
    -5
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc

+ 4
- 1
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -240,7 +240,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
}

CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name) {
const bool need_padding, const std::string &op_name, const std::vector<int64_t> &perm) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input);
CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
@@ -261,6 +261,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
auto kernel_info = std::make_shared<device::KernelInfo>();
trans_node->set_kernel_info(kernel_info);
}
if (op_name == prim::kPrimTranspose->name()) {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
}
MS_EXCEPTION_IF_NULL(kernel_select);
kernel_select->SelectKernel(trans_node);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);


+ 2
- 1
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h View File

@@ -92,7 +92,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
const TypeId &type_id = kTypeUnknown);

CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name);
const bool need_padding, const std::string &op_name,
const std::vector<int64_t> &perm = std::vector<int64_t>{});

CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,


+ 3
- 3
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transpose_for_dynamic_gru_v2.cc View File

@@ -53,9 +53,9 @@ CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto padding_axis = AnfAlgo::GetOutputReshapeType(transdata_node, 0);
KernelSelectPtr kernel_select = std::make_shared<KernelSelect>();
// trans default to hwcn
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0),
kernel_select, false, prim::kPrimTranspose->name());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node);
new_transpose_node =
NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0), kernel_select, false,
prim::kPrimTranspose->name(), std::vector<int64_t>{2, 3, 1, 0});
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
// trans hwcn to output_format


+ 3
- 5
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc View File

@@ -78,16 +78,14 @@ CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePt
false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
// trans hwcn to default_format
new_transpose_node =
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false,
prim::kPrimTranspose->name(), std::vector<int64_t>{3, 2, 0, 1});
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{3, 2, 0, 1}), new_transpose_node);
new_replace_node = new_transpose_node;
} else {
// trans default to hwcn
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
false, prim::kPrimTranspose->name());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node);
false, prim::kPrimTranspose->name(), std::vector<int64_t>{2, 3, 1, 0});
if (output_format == kOpFormat_FRACTAL_ZN_LSTM) {
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
}


Loading…
Cancel
Save