|
|
|
@@ -310,16 +310,16 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) { |
|
|
|
bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) { |
|
|
|
if (input_format.empty()) { |
|
|
|
// Set default format |
|
|
|
if (ge::GetParserContext().type = domi::TENSORFLOW) { |
|
|
|
if (ge::GetParserContext().type == domi::TENSORFLOW) { |
|
|
|
input_format = "NHWC"; |
|
|
|
} else { |
|
|
|
input_format = "NCHW"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} else if (ge::GetParserContext().type = domi::CAFFE) { // caffe |
|
|
|
} else if (ge::GetParserContext().type == domi::CAFFE) { // caffe |
|
|
|
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -329,7 +329,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) { |
|
|
|
GELOGE(ge::FAILED, |
|
|
|
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport); |
|
|
|
return false; |
|
|
|
} else if (ge::GetParserContext().type = domi::TENSORFLOW) { // tf |
|
|
|
} else if (ge::GetParserContext().type ==domi::TENSORFLOW) { // tf |
|
|
|
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -343,7 +343,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclFormat(const string &input_format) { |
|
|
|
domi::Status AclGrphParseUtil::ParseAclFormat(string &input_format) { |
|
|
|
ge::GetParserContext().format = domi::DOMI_TENSOR_ND; |
|
|
|
if (!CheckAclInputFormat(input_format)) { |
|
|
|
GELOGE(PARAM_INVALID, "Check input_format failed"); |
|
|
|
|