|
|
|
@@ -187,10 +187,10 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
void GetTransDataInputFormat(const AnfNodePtr &node, std::string *input_format) { |
|
|
|
void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_format); |
|
|
|
if (AnfAlgo::IsRealKernel(node)) { |
|
|
|
*input_format = AnfAlgo::GetOutputFormat(node, 0); |
|
|
|
*input_format = AnfAlgo::GetOutputFormat(node, idx); |
|
|
|
} else { |
|
|
|
*input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); |
|
|
|
} |
|
|
|
@@ -206,7 +206,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
bool padding_flag = false; |
|
|
|
|
|
|
|
std::string output_format; |
|
|
|
GetTransDataInputFormat(node, &output_format); |
|
|
|
GetTransDataInputFormat(node, output_idx, &output_format); |
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) { |
|
|
|
MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node " |
|
|
|
<< node->DebugString(); |
|
|
|
|