| @@ -54,7 +54,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| CNodePtr trans_data = nullptr; | CNodePtr trans_data = nullptr; | ||||
| std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); | std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); | ||||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; | std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; | ||||
| TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); | |||||
| std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| // if insert transdata for input we need to change the input | // if insert transdata for input we need to change the input | ||||
| @@ -63,7 +62,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; | MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; | ||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); | |||||
| dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); | dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); | ||||
| input_node = AnfAlgo::GetInputNode(cnode, insert_index); | input_node = AnfAlgo::GetInputNode(cnode, insert_index); | ||||
| padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); | padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); | ||||
| @@ -95,7 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| trans_node = reshape_node; | trans_node = reshape_node; | ||||
| } | } | ||||
| // refresh the transdata's format to ori format & dst format | // refresh the transdata's format to ori format & dst format | ||||
| RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis); | |||||
| RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); | |||||
| return trans_node; | return trans_node; | ||||
| } | } | ||||
| @@ -162,22 +160,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||||
| return make_tuple; | return make_tuple; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | |||||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||||
| const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { | const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { | ||||
| MS_EXCEPTION_IF_NULL(trans_data); | MS_EXCEPTION_IF_NULL(trans_data); | ||||
| MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); | |||||
| auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); | |||||
| KernelBuildInfoBuilder builder; | |||||
| builder.SetInputsFormat({input_format}); | |||||
| builder.SetInputReshapeType({reshape_type}); | |||||
| builder.SetInputReshapeType({reshape_type}); | |||||
| builder.SetOutputsFormat({output_format}); | |||||
| builder.SetInputsDeviceType({device_type}); | |||||
| builder.SetOutputsDeviceType({device_type}); | |||||
| builder.SetKernelType(ori_build_info->kernel_type()); | |||||
| builder.SetFusionType(ori_build_info->fusion_type()); | |||||
| builder.SetProcessor(ori_build_info->processor()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); | |||||
| auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); | |||||
| MS_EXCEPTION_IF_NULL(ori_build_info); | |||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info); | |||||
| builder->SetInputsFormat({input_format}); | |||||
| builder->SetInputReshapeType({reshape_type}); | |||||
| builder->SetOutputReshapeType({reshape_type}); | |||||
| builder->SetOutputsFormat({output_format}); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); | |||||
| } | } | ||||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | ||||
| @@ -70,7 +70,7 @@ class KernelQuery { | |||||
| } | } | ||||
| }; | }; | ||||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | using KernelQueryPtr = std::shared_ptr<KernelQuery>; | ||||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | |||||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||||
| const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); | const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); | ||||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | ||||
| @@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| if (origin_format != cur_format && cur_shape.size() > 1) { | if (origin_format != cur_format && cur_shape.size() > 1) { | ||||
| auto kernel_select = std::make_shared<KernelSelect>(); | auto kernel_select = std::make_shared<KernelSelect>(); | ||||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | ||||
| RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node); | |||||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node); | |||||
| final_index = 0; | final_index = 0; | ||||
| MS_EXCEPTION_IF_NULL(final_node); | MS_EXCEPTION_IF_NULL(final_node); | ||||
| MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | ||||
| @@ -69,13 +69,11 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||||
| // trans input_format to hwcn | // trans input_format to hwcn | ||||
| new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | ||||
| false, prim::KPrimTransData->name()); | false, prim::KPrimTransData->name()); | ||||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), | |||||
| new_transdata_node); | |||||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); | |||||
| // trans hwcn to default_format | // trans hwcn to default_format | ||||
| new_transpose_node = | new_transpose_node = | ||||
| NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); | NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); | ||||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), | |||||
| new_transpose_node); | |||||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); | |||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | ||||
| new_replace_node = new_transpose_node; | new_replace_node = new_transpose_node; | ||||
| } else { | } else { | ||||
| @@ -83,14 +81,12 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||||
| new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | ||||
| false, prim::kPrimTranspose->name()); | false, prim::kPrimTranspose->name()); | ||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | ||||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), | |||||
| new_transpose_node); | |||||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); | |||||
| // trans hwcn to output_format | // trans hwcn to output_format | ||||
| new_transdata_node = | new_transdata_node = | ||||
| NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); | NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); | ||||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), | |||||
| new_transdata_node); | |||||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); | |||||
| new_replace_node = new_transdata_node; | new_replace_node = new_transdata_node; | ||||
| } | } | ||||
| FuncGraphManagerPtr manager = func_graph->manager(); | FuncGraphManagerPtr manager = func_graph->manager(); | ||||