diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 4d18e3b28a..8a14b438bb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -54,7 +54,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt CNodePtr trans_data = nullptr; std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); MS_EXCEPTION_IF_NULL(node); // if insert transdata for input we need to change the input @@ -63,7 +62,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; } auto cnode = node->cast(); - dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); input_node = AnfAlgo::GetInputNode(cnode, insert_index); padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); @@ -95,7 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt trans_node = reshape_node; } // refresh the transdata's format to ori format & dst format - RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis); + RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); return trans_node; } @@ -162,22 +160,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const return make_tuple; } } // namespace -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const AnfNodePtr &trans_data, const std::vector &reshape_type) { MS_EXCEPTION_IF_NULL(trans_data); - MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); - auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); - KernelBuildInfoBuilder builder; - builder.SetInputsFormat({input_format}); - builder.SetInputReshapeType({reshape_type}); - builder.SetInputReshapeType({reshape_type}); - builder.SetOutputsFormat({output_format}); - builder.SetInputsDeviceType({device_type}); - builder.SetOutputsDeviceType({device_type}); - builder.SetKernelType(ori_build_info->kernel_type()); - builder.SetFusionType(ori_build_info->fusion_type()); - builder.SetProcessor(ori_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); + auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); + MS_EXCEPTION_IF_NULL(ori_build_info); + auto builder = std::make_shared(ori_build_info); + builder->SetInputsFormat({input_format}); + builder->SetInputReshapeType({reshape_type}); + builder->SetOutputReshapeType({reshape_type}); + builder->SetOutputsFormat({output_format}); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); } CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index a385e574a4..ad48ca5291 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -70,7 +70,7 @@ class KernelQuery { } }; using KernelQueryPtr = std::shared_ptr; -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 2bf79d389d..f909dae9e4 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP if (origin_format != cur_format && cur_shape.size() > 1) { auto kernel_select = std::make_shared(); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node); + RefreshKernelBuildInfo(cur_format, origin_format, final_node); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc index 0305104f5b..bfb7e50486 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc @@ -69,13 +69,11 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n // trans input_format to hwcn new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), - new_transdata_node); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); // trans hwcn to default_format new_transpose_node = NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), - new_transpose_node); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); new_replace_node = new_transpose_node; } else { @@ -83,14 +81,12 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, false, prim::kPrimTranspose->name()); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), - new_transpose_node); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); // trans hwcn to output_format new_transdata_node = NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), - new_transdata_node); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); new_replace_node = new_transdata_node; } FuncGraphManagerPtr manager = func_graph->manager();