| @@ -54,7 +54,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| CNodePtr trans_data = nullptr; | |||
| std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); | |||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; | |||
| TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); | |||
| std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // if insert transdata for input we need to change the input | |||
| @@ -63,7 +62,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); | |||
| dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); | |||
| input_node = AnfAlgo::GetInputNode(cnode, insert_index); | |||
| padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); | |||
| @@ -95,7 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| trans_node = reshape_node; | |||
| } | |||
| // refresh the transdata's format to ori format & dst format | |||
| RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis); | |||
| RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); | |||
| return trans_node; | |||
| } | |||
| @@ -162,22 +160,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||
| return make_tuple; | |||
| } | |||
| } // namespace | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { | |||
| MS_EXCEPTION_IF_NULL(trans_data); | |||
| MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); | |||
| auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({input_format}); | |||
| builder.SetInputReshapeType({reshape_type}); | |||
| builder.SetInputReshapeType({reshape_type}); | |||
| builder.SetOutputsFormat({output_format}); | |||
| builder.SetInputsDeviceType({device_type}); | |||
| builder.SetOutputsDeviceType({device_type}); | |||
| builder.SetKernelType(ori_build_info->kernel_type()); | |||
| builder.SetFusionType(ori_build_info->fusion_type()); | |||
| builder.SetProcessor(ori_build_info->processor()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); | |||
| auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); | |||
| MS_EXCEPTION_IF_NULL(ori_build_info); | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info); | |||
| builder->SetInputsFormat({input_format}); | |||
| builder->SetInputReshapeType({reshape_type}); | |||
| builder->SetOutputReshapeType({reshape_type}); | |||
| builder->SetOutputsFormat({output_format}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); | |||
| } | |||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | |||
| @@ -70,7 +70,7 @@ class KernelQuery { | |||
| } | |||
| }; | |||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); | |||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | |||
| @@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| if (origin_format != cur_format && cur_shape.size() > 1) { | |||
| auto kernel_select = std::make_shared<KernelSelect>(); | |||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node); | |||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node); | |||
| final_index = 0; | |||
| MS_EXCEPTION_IF_NULL(final_node); | |||
| MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | |||
| @@ -69,13 +69,11 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| // trans input_format to hwcn | |||
| new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | |||
| false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), | |||
| new_transdata_node); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); | |||
| // trans hwcn to default_format | |||
| new_transpose_node = | |||
| NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), | |||
| new_transpose_node); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | |||
| new_replace_node = new_transpose_node; | |||
| } else { | |||
| @@ -83,14 +81,12 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | |||
| false, prim::kPrimTranspose->name()); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), | |||
| new_transpose_node); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); | |||
| // trans hwcn to output_format | |||
| new_transdata_node = | |||
| NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), | |||
| new_transdata_node); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); | |||
| new_replace_node = new_transdata_node; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||