Browse Source

fix bug of got a error transdata's dest format

tags/v0.3.0-alpha
lianliguang chang zherui 5 years ago
parent
commit
d3400cde01
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc

+ 3
- 3
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc View File

@@ -187,10 +187,10 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
return node;
}

void GetTransDataInputFormat(const AnfNodePtr &node, std::string *input_format) {
void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) {
MS_EXCEPTION_IF_NULL(input_format);
if (AnfAlgo::IsRealKernel(node)) {
*input_format = AnfAlgo::GetOutputFormat(node, 0);
*input_format = AnfAlgo::GetOutputFormat(node, idx);
} else {
*input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
}
@@ -206,7 +206,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
bool padding_flag = false;

std::string output_format;
GetTransDataInputFormat(node, &output_format);
GetTransDataInputFormat(node, output_idx, &output_format);
if (output_format == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
<< node->DebugString();


Loading…
Cancel
Save