Browse Source

Description:Support model_exit in GE

Team:HISI_SW
Feature or Bugfix:Feature
pull/63/head
l00444296 5 years ago
parent
commit
018fdb9397
2 changed files with 7 additions and 7 deletions
  1. +5
    -5
      parser/common/acl_graph_parser_util.cc
  2. +2
    -2
      parser/common/acl_graph_parser_util.h

+ 5
- 5
parser/common/acl_graph_parser_util.cc View File

@@ -310,16 +310,16 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
return SUCCESS;
}

bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) {
bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
if (input_format.empty()) {
// Set default format
if (ge::GetParserContext().type = domi::TENSORFLOW) {
if (ge::GetParserContext().type == domi::TENSORFLOW) {
input_format = "NHWC";
} else {
input_format = "NCHW";
}
return true;
} else if (ge::GetParserContext().type = domi::CAFFE) { // caffe
} else if (ge::GetParserContext().type == domi::CAFFE) { // caffe
if (kCaffeSupportInputFormatSet.find(input_format) != kCaffeSupportInputFormatSet.end()) {
return true;
}
@@ -329,7 +329,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) {
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
return false;
} else if (ge::GetParserContext().type = domi::TENSORFLOW) { // tf
} else if (ge::GetParserContext().type ==domi::TENSORFLOW) { // tf
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) {
return true;
}
@@ -343,7 +343,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(const string &input_format) {
return true;
}

domi::Status AclGrphParseUtil::ParseAclFormat(const string &input_format) {
domi::Status AclGrphParseUtil::ParseAclFormat(string &input_format) {
ge::GetParserContext().format = domi::DOMI_TENSOR_ND;
if (!CheckAclInputFormat(input_format)) {
GELOGE(PARAM_INVALID, "Check input_format failed");


+ 2
- 2
parser/common/acl_graph_parser_util.h View File

@@ -53,8 +53,8 @@ class AclGrphParseUtil {
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
domi::Status ParseAclLogLevel(const std::string &log);
bool CheckAclInputFormat(const string &input_format);
domi::Status ParseAclFormat(const std::string &input_format);
bool CheckAclInputFormat(string &input_format);
domi::Status ParseAclFormat(std::string &input_format);
bool ParseInputShape(const std::string &input_shape,
std::unordered_map<std::string, vector<int64_t>> &shape_map,
vector<pair<std::string, vector<int64_t>>> &user_shape_map,


Loading…
Cancel
Save