|
|
|
@@ -70,7 +70,10 @@ const char *const kDigitError = "is not digit"; |
|
|
|
const std::string kGraphDefaultName = "domi_default"; |
|
|
|
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; |
|
|
|
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; |
|
|
|
|
|
|
|
static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"}; |
|
|
|
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; |
|
|
|
static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; |
|
|
|
static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; |
|
|
|
/// The maximum length of the file. |
|
|
|
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 |
|
|
|
const int kMaxFileSizeLimit = INT_MAX; |
|
|
|
@@ -307,8 +310,45 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) { |
|
|
|
if (input_format.empty()) { |
|
|
|
// Set default format |
|
|
|
if (ge::GetParserContext().type = domi::TENSORFLOW) { |
|
|
|
input_format = "NHWC"; |
|
|
|
} else { |
|
|
|
input_format = "NCHW"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} else if (ge::GetParserContext().type = domi::CAFFE) { // caffe |
|
|
|
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// only support NCHW ND |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kCaffeFormatSupport}); |
|
|
|
GELOGE(ge::FAILED, |
|
|
|
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport); |
|
|
|
return false; |
|
|
|
} else if (ge::GetParserContext().type = domi::TENSORFLOW) { // tf |
|
|
|
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// only support NCHW NHWC ND NCDHW NDHWC |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kTFFormatSupport}); |
|
|
|
GELOGE(ge::FAILED, |
|
|
|
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kTFFormatSupport); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclFormat(const string &input_format) { |
|
|
|
ge::GetParserContext().format = domi::DOMI_TENSOR_ND; |
|
|
|
if (!CheckAclInputFormat(input_format)) { |
|
|
|
GELOGE(PARAM_INVALID, "Check input_format failed"); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (!input_format.empty()) { |
|
|
|
auto iter = kInputFormatStrToGeformat.find(input_format); |
|
|
|
if (iter != kInputFormatStrToGeformat.end()) { |
|
|
|
|