Browse Source

clean code

pull/640/head
13291271729 3 years ago
parent
commit
3985ac6c86
3 changed files with 30 additions and 191 deletions
  1. +30
    -120
      parser/common/convert/message2operator.cc
  2. +0
    -9
      parser/common/convert/message2operator.h
  3. +0
    -62
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 30
- 120
parser/common/convert/message2operator.cc View File

@@ -44,16 +44,12 @@ Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *mes
for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
std::cout << "1111" << std::endl;
if (ParseRepeatedField(reflection, message, field, ops) != SUCCESS) {
std::cout << "[Parse][RepeatedField]" << field->name().c_str() << "failed." << std::endl;
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
std::cout << "2222" << std::endl;
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
std::cout << "[Parse][Field]" << field->name().c_str() << "failed." << std::endl;
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
@@ -62,57 +58,24 @@ Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *mes
return SUCCESS;
}

Status Message2Operator::ParseBaseTypeField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, ge::Operator &ops) {
std::cout << "weewf" << std::endl;
switch (field->cpp_type()) {
case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
int32_t value = reflection->GetInt32(*message, field);
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name().c_str(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
std::cout << "weCPPTYPE_UINT32ewf" << std::endl;
uint32_t value = reflection->GetUInt32(*message, field);
GELOGD("Parse result(%s : %u)", field->name().c_str(), value);
(void)ops.SetAttr(field->name().c_str(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
std::cout << "CPPTYPE_INT64" << std::endl;
int64_t value = reflection->GetInt64(*message, field);
GELOGD("Parse result(%s : %ld)", field->name().c_str(), value);
(void)ops.SetAttr(field->name().c_str(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
float value = reflection->GetFloat(*message, field);
GELOGD("Parse result(%s : %f)", field->name().c_str(), value);
(void)ops.SetAttr(field->name().c_str(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
bool value = reflection->GetBool(*message, field);
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name().c_str(), value);
break;
}
default: {
return FAILED;
}
}
return SUCCESS;
}

Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
GE_CHK_BOOL_RET_STATUS(ParseBaseTypeField(reflection, message, field, ops) == FAILED, SUCCESS,
"Parse field: %s success.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name().c_str(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
@@ -145,74 +108,6 @@ Status Message2Operator::ParseField(const google::protobuf::Reflection *reflecti
return SUCCESS;
}

Status Message2Operator::ParseRepeatedBaseTypeField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
ge::Operator &ops, const int field_size) {
switch (field->cpp_type()) {
case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
std::vector<int32_t> attr_value;
for (int i = 0; i < field_size; i++) {
int32_t value = reflection->GetRepeatedInt32(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
std::vector<uint32_t> attr_value;
std::cout << "sdfsfCPPTYPE_UINT32";
for (int i = 0; i < field_size; i++) {
uint32_t value = reflection->GetRepeatedUInt32(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
std::vector<int64_t> attr_value;
for (int i = 0; i < field_size; i++) {
int64_t value = reflection->GetRepeatedInt64(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
std::cout << "sdfsfCPPTYPE_FLOAT" << std::endl;
std::vector<float> attr_value;
for (int i = 0; i < field_size; i++) {
float value = reflection->GetRepeatedFloat(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
std::vector<bool> attr_value;
for (int i = 0; i < field_size; i++) {
bool value = reflection->GetRepeatedBool(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
std::vector<string> attr_value;
for (int i = 0; i < field_size; i++) {
string value = reflection->GetRepeatedString(*message, field, i);
attr_value.push_back(value);
}
(void)ops.SetAttr(field->name().c_str(), attr_value);
break;
}
default: {
return FAILED;
}
}
return SUCCESS;
}

Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
@@ -224,9 +119,24 @@ Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}
GE_CHK_BOOL_RET_STATUS(ParseRepeatedBaseTypeField(reflection, message, field, ops, field_size) == FAILED, SUCCESS,
"Parse repeated field: %s success.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
std::vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name().c_str(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()],


+ 0
- 9
parser/common/convert/message2operator.h View File

@@ -33,15 +33,6 @@ class Message2Operator {
static Status ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, ge::Operator &ops);

static Status ParseBaseTypeField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, ge::Operator &ops);

static Status ParseRepeatedBaseTypeField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
ge::Operator &ops, const int field_size);
};
} // namespace ge
#endif // PARSER_MESSAGE2OPERATOR_H

+ 0
- 62
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -3622,68 +3622,6 @@ TEST_F(UtestTensorflowParser, tensorflow_Message2Operator_ParseOperatorAttrs_tes
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestTensorflowParser, tensorflow_Message2Operator_ParseOperatorAttrs_success)
{
Message2Operator mess2Op;
tensorflow::NodeDef nodedef;
nodedef.set_name("const");
nodedef.set_op("const");

ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>();
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);

tensorflow::AttrValue value;
value.set_s("string");
TensorFlowUtil::AddNodeAttr("str", value, &nodedef);
value.clear_value();
value.set_i(1);
TensorFlowUtil::AddNodeAttr("num", value, &nodedef);
value.clear_value();
domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
value.set_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("float", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_UINT32;
value.set_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("uint32", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_INT64;
value.set_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("int64", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_BOOL;
value.set_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("bool", value, &nodedef);
Status ret = mess2Op.ParseOperatorAttrs(&nodedef, 1, ops);
EXPECT_EQ(ret, SUCCESS);

value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_STRING;
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("str_list", value, &nodedef);
value.clear_value();
value.mutable_list()->add_i(1);
TensorFlowUtil::AddNodeAttr("num_list", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("float_list", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_UINT32;
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("uint32_list", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_INT64;
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("int64_list", value, &nodedef);
value.clear_value();
VALUE_TYPE = domi::tensorflow::DataType::DT_BOOL;
value.mutable_list()->add_type(VALUE_TYPE);
TensorFlowUtil::AddNodeAttr("bool_list", value, &nodedef);
ret = mess2Op.ParseOperatorAttrs(&nodedef, 1, ops);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_RepeatedEnum2Json_test)
{
Pb2Json toJson;


Loading…
Cancel
Save