|
|
|
@@ -26,8 +26,11 @@ |
|
|
|
#include "common/debug/log.h" |
|
|
|
#include "common/op/ge_op_utils.h" |
|
|
|
#include "common/properties_manager.h" |
|
|
|
#include "common/type.h" |
|
|
|
#include "common/util.h" |
|
|
|
#include "ge/ge_api_types.h" |
|
|
|
#include "graph/opsproto_manager.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "omg/parser/parser_inner_ctx.h" |
|
|
|
#include "tbe_plugin_loader.h" |
|
|
|
#include "framework/common/debug/ge_log.h" |
|
|
|
@@ -53,7 +56,7 @@ static std::map<std::string, domiTensorFormat_t> kInputFormatStrToGeformat = { |
|
|
|
{"NCDHW", domi::DOMI_TENSOR_NCDHW}, |
|
|
|
{"NDHWC", domi::DOMI_TENSOR_NDHWC} |
|
|
|
}; |
|
|
|
static char *const kIsOutputAdjustHwLayoutKey = "is_output_adjust_hw_layout"; |
|
|
|
const char *const kIsOutputAdjustHwLayoutKey = "is_output_adjust_hw_layout"; |
|
|
|
// datatype/formats from user to GE, Unified to util interface file later |
|
|
|
const std::map<std::string, ge::DataType> kOutputTypeSupportDatatype = { |
|
|
|
{"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; |
|
|
|
@@ -64,7 +67,10 @@ const char *const kSplitError1 = "size not equal to 2 split by \":\""; |
|
|
|
const char *const kEmptyError = "can not be empty"; |
|
|
|
const char *const kFloatNumError = "exist float number"; |
|
|
|
const char *const kDigitError = "is not digit"; |
|
|
|
|
|
|
|
const std::string kGraphDefaultName = "domi_default"; |
|
|
|
const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; |
|
|
|
const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; |
|
|
|
c |
|
|
|
/// The maximum length of the file. |
|
|
|
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 |
|
|
|
const int kMaxFileSizeLimit = INT_MAX; |
|
|
|
@@ -180,6 +186,22 @@ static domi::Status StringToInt(std::string &str, int32_t &value) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { |
|
|
|
int32_t out_size = op_desc->GetOutputsSize(); |
|
|
|
if (index < 0 || index >= out_size) { |
|
|
|
GELOGE(domi::FAILED, |
|
|
|
"out_node [%s] output index:%d must be smaller " |
|
|
|
"than node output size:%d and can not be negative!", |
|
|
|
op_desc->GetName().c_str(), index, out_size); |
|
|
|
std::string fail_reason = "output index:" + to_string(index) + " must be smaller than output size:" + |
|
|
|
to_string(out_size) + " and can not be negative!"; |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, |
|
|
|
{"out_nodes", op_desc->GetName(), fail_reason}); |
|
|
|
return domi::FAILED; |
|
|
|
} |
|
|
|
return domi::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
domi::Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) { |
|
|
|
std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes; |
|
|
|
std::set<std::string> out_nodes_info; |
|
|
|
@@ -291,7 +313,7 @@ domi::Status AclGrphParseUtil::ParseAclFormat(const string &input_format) { |
|
|
|
} |
|
|
|
|
|
|
|
bool AclGrphParseUtil::ParseInputShape(const string &input_shape, |
|
|
|
unordered_map<string, vector<int64_t>> &shape_map, |
|
|
|
std::unordered_map<string, vector<int64_t>> &shape_map, |
|
|
|
vector<pair<string, vector<int64_t>>> &user_shape_map, |
|
|
|
bool is_dynamic_input) { |
|
|
|
vector<string> shape_vec = StringUtils::Split(input_shape, ';'); |
|
|
|
@@ -386,7 +408,7 @@ domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape, bool is_ |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
unordered_map<string, vector<int64_t>> &shape_map = ge::GetParserContext().input_dims; |
|
|
|
std::unordered_map<string, vector<int64_t>> &shape_map = ge::GetParserContext().input_dims; |
|
|
|
if (!ParseInputShape(input_shape, ge::GetParserContext().input_dims, |
|
|
|
ge::GetParserContext().user_input_dims, is_dynamic_input) || shape_map.empty()) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); |
|
|
|
@@ -496,6 +518,7 @@ domi::Status AclGrphParseUtil::ParseAclOpConf(const char *op_conf) { |
|
|
|
// Return map and put it into ATC global variable |
|
|
|
ge::GetParserContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap(); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) { |
|
|
|
@@ -504,9 +527,10 @@ domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fu |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
ge::GetParserContext().enable_scope_fusion_passes = enable_scope_fusion_passes; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
static void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, |
|
|
|
void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, |
|
|
|
const string &fp16_nodes_name, uint32_t index, |
|
|
|
OpDescPtr &op_desc) { |
|
|
|
if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { |
|
|
|
@@ -832,7 +856,9 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string |
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOpConf(op_conf_str.c_str()) != SUCCESS, |
|
|
|
return PARAM_INVALID, "Parse op_name_map failed"); |
|
|
|
|
|
|
|
GetAclParams(parser_params, "output", graph_name); |
|
|
|
string tmp_name; |
|
|
|
GetAclParams(parser_params, "output", tmp_name); |
|
|
|
graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + CurrentTimeInStr()) : tmp_name; |
|
|
|
|
|
|
|
string enable_scope_fusion_passes; |
|
|
|
GetAclParams(parser_params, "enable_scope_fusion_passes", enable_scope_fusion_passes); |
|
|
|
@@ -846,7 +872,7 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, |
|
|
|
const std::map<std::string, std::string> &parser_params) { |
|
|
|
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, |
|
|
|
// log |
|
|
|
compute_graph = GraphUtils::GetComputeGraph(graph); |
|
|
|
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); |
|
|
|
|
|
|
|
string input_fp16_nodes; |
|
|
|
GetAclParams(parser_params, "input_fp16_nodes", input_fp16_nodes); |
|
|
|
|