Browse Source

Feature:Support user options of aclgrphParse interface

pull/63/head
l00444296 5 years ago
parent
commit
3f8c34eea8
2 changed files with 204 additions and 102 deletions
  1. +114
    -16
      parser/common/acl_graph_parser_util.cc
  2. +90
    -86
      parser/common/acl_graph_parser_util.h

+ 114
- 16
parser/common/acl_graph_parser_util.cc View File

@@ -73,14 +73,17 @@ const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";
static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"};
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
const int kMaxBuffSize = 256;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
const int kOutputTypeNode = 0;
const int kOutputTypeIndex = 1;
const int kOutputTypeDataType = 2;

vector<string> SplitInputShape(const std::string &input_shape) {
vector<string> shape_pair_vec;
@@ -137,7 +140,7 @@ static void GetOpsProtoPath(string &opsproto_path) {
}

static void GetAclParams(const std::map<string, string> &parser_params,
const string &key, string &value) {
const string &key, string &value) {
auto iter = parser_params.find(key);
if (iter != parser_params.end()) {
value = iter->second;
@@ -375,7 +378,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
return false;
} else if (ge::GetParserContext().type ==domi::TENSORFLOW) { // tf
} else if (ge::GetParserContext().type == domi::TENSORFLOW) { // tf
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) {
return true;
}
@@ -494,7 +497,7 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
return true;
}

//Parse user input shape info
// Parse user input shape info
domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape, bool is_dynamic_input) {
ge::GetParserContext().input_dims.clear();
ge::GetParserContext().user_input_dims.clear();
@@ -541,7 +544,8 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
GELOGE(PARAM_INVALID,
"This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
// stoi: The method may throw an exception: invalid_argument/out_of_range
@@ -647,8 +651,8 @@ domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fu
}

void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec,
const string &fp16_nodes_name, uint32_t index,
OpDescPtr &op_desc) {
const string &fp16_nodes_name, uint32_t index,
OpDescPtr &op_desc) {
if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) {
if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) {
GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str());
@@ -727,6 +731,10 @@ domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr
std::string compress_nodes;
ifs >> compress_nodes;
ifs.close();
if (compress_nodes.empty()) {
GELOGW("Compress weight of nodes info is empty");
return SUCCESS;
}
GELOGI("Compress weight of nodes: %s", compress_nodes.c_str());

vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';');
@@ -763,14 +771,14 @@ domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type
return domi::FAILED;
}
ge::DataType tmp_dt;
std::string node_name = StringUtils::Trim(node_index_type_v[0]);
std::string index_str = StringUtils::Trim(node_index_type_v[1]);
std::string node_name = StringUtils::Trim(node_index_type_v[kOutputTypeNode]);
std::string index_str = StringUtils::Trim(node_index_type_v[kOutputTypeIndex]);
int32_t index;
if (StringToInt(index_str, index) != SUCCESS) {
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
return domi::FAILED;
}
std::string dt_value = StringUtils::Trim(node_index_type_v[2]);
std::string dt_value = StringUtils::Trim(node_index_type_v[kOutputTypeDataType]);
auto it = kOutputTypeSupportDatatype.find(dt_value);
if (it == kOutputTypeSupportDatatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
@@ -811,7 +819,8 @@ void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::Node
std::string node_name = output_nodes_info[i].first->GetName();
int32_t index = output_nodes_info[i].second;
if (i < ge::GetParserContext().out_top_names.size()) {
output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + ge::GetParserContext().out_top_names[i]);
output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" +
ge::GetParserContext().out_top_names[i]);
} else {
GELOGW("Get top name of node [%s] fail.", node_name.c_str());
output_nodes_name.push_back(node_name + ":" + std::to_string(index));
@@ -982,6 +991,86 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<std::string, std::str
return SUCCESS;
}

domi::Status AclGrphParseUtil::CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input) {
if (!is_dynamic_input)
{
for (auto node : graph->GetDirectNode())
{
if (node->GetType() == ge::parser::DATA)
{
auto data_op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(data_op_desc);
auto tensor_desc = data_op_desc->MutableInputDesc(0);
GE_CHECK_NOTNULL(tensor_desc);
for (auto dim : tensor_desc->GetShape().GetDims())
{
if (dim < 0)
{
GELOGE(PARAM_INVALID,
"Input op [%s] shape %ld is negative, maybe you should set input_shape to specify its shape",
node->GetName().c_str(), dim);
const string reason = "maybe you should set input_shape to specify its shape";
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {node->GetName(), to_string(dim), reason});
return PARAM_INVALID;
}
}
}
}
}
for (auto it : ge::GetParserContext().user_input_dims)
{
std::string node_name = it.first;
ge::NodePtr node = graph->FindNode(node_name);
if (node == nullptr)
{
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != ge::parser::DATA)
{
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str());
return PARAM_INVALID;
}
}
return SUCCESS;
}

domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) {
GE_CHECK_NOTNULL(graph);
unordered_map<string, string> graphNodeTypes;
for (const NodePtr &node : graph->GetAllNodes())
{
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr)
{
GELOGE(PARAM_INVALID, "Invalid parameter for opDesc.");
return PARAM_INVALID;
}
graphNodeTypes[op_desc->GetType()] = "";
}
std::map<std::string, std::string> &propertiesMap = ge::GetParserContext().op_conf_map;
if (propertiesMap.empty())
{
ErrorManager::GetInstance().ATCReportErrMessage(
"E10003", {"parameter", "value", "reason"}, {"op_name_map", op_conf, "the file content is empty"});
GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!");
return PARAM_INVALID;
}
for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++)
{
GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(),
ErrorManager::GetInstance().ATCReportErrMessage(
"E10003",
{"parameter", "value", "reason"},
{"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"});
GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map.");
return PARAM_INVALID;);
}
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string, std::string> &parser_params,
string &graph_name) {
@@ -995,7 +1084,6 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclLogLevel(log_level) != SUCCESS,
return PARAM_INVALID, "Parse log_level failed");


string input_format;
GetAclParams(parser_params, ge::ir_option::INPUT_FORMAT, input_format);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclFormat(input_format) != SUCCESS,
@@ -1017,7 +1105,6 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS,
return PARAM_INVALID, "Parse out_nodes failed");


string is_output_adjust_hw_layout;
GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS,
@@ -1050,13 +1137,24 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,

string is_input_adjust_hw_layout;
GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS,
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes,
is_input_adjust_hw_layout) != SUCCESS,
return PARAM_INVALID, "Parse input_fp16_nodes failed");
//GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input));

bool is_dynamic_input = ge::GetParserContext().is_dynamic_input;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclInputShapeNode(compute_graph, is_dynamic_input) != SUCCESS,
return PARAM_INVALID,
"Check nodes input_shape info failed");

string compress_weight_conf;
GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS,
return PARAM_INVALID, "Parse compress_weight_conf failed");
string op_conf_str;
GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclOpNameMap(compute_graph, op_conf_str) != SUCCESS,
return PARAM_INVALID, "Check op_name_map info failed");
return SUCCESS;
}



+ 90
- 86
parser/common/acl_graph_parser_util.h View File

@@ -77,105 +77,109 @@ class AclGrphParseUtil {
std::map<std::string, vector<std::string>> &output_node_dt_map);
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input);
domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf);
};

namespace parser {
///
/// @ingroup: domi_common
/// @brief: get length of file
/// @param [in] input_file: path of file
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
///
extern long GetFileLength(const std::string &input_file);
namespace parser
{
///
/// @ingroup: domi_common
/// @brief: get length of file
/// @param [in] input_file: path of file
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
///
extern long GetFileLength(const std::string &input_file);

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);
///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Obtains the absolute time (timestamp) of the current system.
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestamp();
///
/// @ingroup domi_common
/// @brief Obtains the absolute time (timestamp) of the current system.
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestamp();

///
/// @ingroup domi_common
/// @brief Reads all data from a binary file.
/// @param [in] file_name path of file
/// @param [out] buffer Output memory address, which needs to be released by the caller.
/// @param [out] length Output memory size
/// @return false fail
/// @return true success
///
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);
///
/// @ingroup domi_common
/// @brief Reads all data from a binary file.
/// @param [in] file_name path of file
/// @param [out] buffer Output memory address, which needs to be released by the caller.
/// @param [out] length Output memory size
/// @return false fail
/// @return true success
///
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);

///
/// @ingroup domi_common
/// @brief proto file in bianary format
/// @param [in] file path of proto file
/// @param [out] proto memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromBinaryFile(const char *file, Message *proto);
///
/// @ingroup domi_common
/// @brief proto file in bianary format
/// @param [in] file path of proto file
/// @param [out] proto memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromBinaryFile(const char *file, Message *proto);

///
/// @ingroup domi_common
/// @brief Reads the proto structure from an array.
/// @param [in] data proto data to be read
/// @param [in] size proto data size
/// @param [out] proto Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromArray(const void *data, int size, Message *proto);
///
/// @ingroup domi_common
/// @brief Reads the proto structure from an array.
/// @param [in] data proto data to be read
/// @param [in] size proto data size
/// @param [out] proto Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromArray(const void *data, int size, Message *proto);

///
/// @ingroup domi_proto
/// @brief Reads the proto file in the text format.
/// @param [in] file path of proto file
/// @param [out] message Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromText(const char *file, google::protobuf::Message *message);
///
/// @ingroup domi_proto
/// @brief Reads the proto file in the text format.
/// @param [in] file path of proto file
/// @param [out] message Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromText(const char *file, google::protobuf::Message *message);

bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);
bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
domi::Status GetOriginalType(const ge::NodePtr &node, string &type);
///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
domi::Status GetOriginalType(const ge::NodePtr &node, string &type);

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);
///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();
///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();

template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args)
{
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}

/// @ingroup math_util


Loading…
Cancel
Save