| @@ -187,10 +187,10 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An | |||||
| return node; | return node; | ||||
| } | } | ||||
| void GetTransDataInputFormat(const AnfNodePtr &node, std::string *input_format) { | |||||
| void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) { | |||||
| MS_EXCEPTION_IF_NULL(input_format); | MS_EXCEPTION_IF_NULL(input_format); | ||||
| if (AnfAlgo::IsRealKernel(node)) { | if (AnfAlgo::IsRealKernel(node)) { | ||||
| *input_format = AnfAlgo::GetOutputFormat(node, 0); | |||||
| *input_format = AnfAlgo::GetOutputFormat(node, idx); | |||||
| } else { | } else { | ||||
| *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); | *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); | ||||
| } | } | ||||
| @@ -206,7 +206,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||||
| bool padding_flag = false; | bool padding_flag = false; | ||||
| std::string output_format; | std::string output_format; | ||||
| GetTransDataInputFormat(node, &output_format); | |||||
| GetTransDataInputFormat(node, output_idx, &output_format); | |||||
| if (output_format == kOpFormat_NC1KHKWHWC0) { | if (output_format == kOpFormat_NC1KHKWHWC0) { | ||||
| MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node " | MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node " | ||||
| << node->DebugString(); | << node->DebugString(); | ||||