Browse Source

Description:Support model_exit in GE

Team:HISI_SW
Feature or Bugfix:Feature
pull/63/head
l00444296 5 years ago
parent
commit
b2d02a7d9e
2 changed files with 42 additions and 1 deletions
  1. +41
    -1
      parser/common/acl_graph_parser_util.cc
  2. +1
    -0
      parser/common/acl_graph_parser_util.h

+ 41
- 1
parser/common/acl_graph_parser_util.cc View File

@@ -70,7 +70,10 @@ const char *const kDigitError = "is not digit";
const std::string kGraphDefaultName = "domi_default";
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"";
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";

static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"};
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
@@ -307,8 +310,45 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
return SUCCESS;
}

bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) {
if (input_format.empty()) {
// Set default format
if (ge::GetParserContext().type = domi::TENSORFLOW) {
input_format = "NHWC";
} else {
input_format = "NCHW";
}
return true;
} else if (ge::GetParserContext().type = domi::CAFFE) { // caffe
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) {
return true;
}
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
return false;
} else if (ge::GetParserContext().type = domi::TENSORFLOW) { // tf
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) {
return true;
}
// only support NCHW NHWC ND NCDHW NDHWC
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kTFFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kTFFormatSupport);
return false;
}
return true;
}

domi::Status AclGrphParseUtil::ParseAclFormat(const string &input_format) {
ge::GetParserContext().format = domi::DOMI_TENSOR_ND;
if (!CheckAclInputFormat(input_format)) {
GELOGE(PARAM_INVALID, "Check input_format failed");
return PARAM_INVALID;
}
if (!input_format.empty()) {
auto iter = kInputFormatStrToGeformat.find(input_format);
if (iter != kInputFormatStrToGeformat.end()) {


+ 1
- 0
parser/common/acl_graph_parser_util.h View File

@@ -53,6 +53,7 @@ class AclGrphParseUtil {
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
domi::Status ParseAclLogLevel(const std::string &log);
bool CheckAclInputFormat(const string &input_format);
domi::Status ParseAclFormat(const std::string &input_format);
bool ParseInputShape(const std::string &input_shape,
std::unordered_map<std::string, vector<int64_t>> &shape_map,


Loading…
Cancel
Save