|
|
|
@@ -78,16 +78,14 @@ CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePt |
|
|
|
false, prim::KPrimTransData->name()); |
|
|
|
RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis); |
|
|
|
// trans hwcn to default_format |
|
|
|
new_transpose_node = |
|
|
|
NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); |
|
|
|
new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, |
|
|
|
prim::kPrimTranspose->name(), std::vector<int64_t>{3, 2, 0, 1}); |
|
|
|
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{3, 2, 0, 1}), new_transpose_node); |
|
|
|
new_replace_node = new_transpose_node; |
|
|
|
} else { |
|
|
|
// trans default to hwcn |
|
|
|
new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, |
|
|
|
false, prim::kPrimTranspose->name()); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node); |
|
|
|
false, prim::kPrimTranspose->name(), std::vector<int64_t>{2, 3, 1, 0}); |
|
|
|
if (output_format == kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node); |
|
|
|
} |
|
|
|
|