Browse Source

fix

pull/321/head
wjm 4 years ago
parent
commit
759b5e419a
3 changed files with 2 additions and 183 deletions
  1. +2
    -141
      parser/caffe/caffe_parser.cc
  2. +0
    -40
      parser/caffe/caffe_parser.h
  3. +0
    -2
      parser/onnx/onnx_parser.cc

+ 2
- 141
parser/caffe/caffe_parser.cc View File

@@ -21,6 +21,7 @@
#include <sstream>
#include <memory>
#include <algorithm>
#include "common/convert/message2operator.h"
#include "parser/common/convert/pb2json.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/op_map.h"
@@ -578,7 +579,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co
return FAILED;
}

if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) {
if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str());
return FAILED;
}
@@ -589,146 +590,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co
return SUCCESS;
}

Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return FAILED;
}

const google::protobuf::Reflection *reflection = message->GetReflection();
GE_CHECK_NOTNULL(reflection);
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);

for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
}
}
return SUCCESS;
}

Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field);
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str());
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str());
return FAILED;
}
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse field: %s success.", field->name().c_str());
return SUCCESS;
}

Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth,
ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
int field_size = reflection->FieldSize(*message, field);
if (field_size <= 0) {
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str());
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}

switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(),
message_json[field->name()], false);
std::string repeated_message_str;
try {
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what());
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string.");
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string.");
return FAILED;
}
(void)ops.SetAttr(field->name(), repeated_message_str);
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse repeated field: %s success.", field->name().c_str());
return SUCCESS;
}

void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) {
auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name);
if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) {


+ 0
- 40
parser/caffe/caffe_parser.h View File

@@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
*/
Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message,
int index, std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Parse message and set operator attrs
* @param [in] message, message of model
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse message successfully
* @return FAILED parse message failed
*/
Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse repeated field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info by vector
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/**
* @ingroup domi_omg
* @brief Add blob information to the bottom_blobs_map and top_blobs_map_


+ 0
- 2
parser/onnx/onnx_parser.cc View File

@@ -566,8 +566,6 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::
Status status = FAILED;
domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type);
if (parse_param_func == nullptr) {
//std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser);
//GE_CHECK_NOTNULL(onnx_op_parser);
status = op_parser->ParseParams(node_proto, op);
} else {
ge::Operator op_src(node_proto->name(), op_type);


Loading…
Cancel
Save