From 3985ac6c86f2833bee2c7f03d14a651436e9dd60 Mon Sep 17 00:00:00 2001 From: 13291271729 Date: Wed, 31 Aug 2022 16:03:59 +0800 Subject: [PATCH] clean code --- parser/common/convert/message2operator.cc | 150 ++++-------------- parser/common/convert/message2operator.h | 9 -- .../tensorflow_parser_unittest.cc | 62 -------- 3 files changed, 30 insertions(+), 191 deletions(-) diff --git a/parser/common/convert/message2operator.cc b/parser/common/convert/message2operator.cc index e6cc662..e02643e 100644 --- a/parser/common/convert/message2operator.cc +++ b/parser/common/convert/message2operator.cc @@ -44,16 +44,12 @@ Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *mes for (auto &field : field_desc) { GE_CHECK_NOTNULL(field); if (field->is_repeated()) { - std::cout << "1111" << std::endl; if (ParseRepeatedField(reflection, message, field, ops) != SUCCESS) { - std::cout << "[Parse][RepeatedField]" << field->name().c_str() << "failed." << std::endl; GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); return FAILED; } } else { - std::cout << "2222" << std::endl; if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { - std::cout << "[Parse][Field]" << field->name().c_str() << "failed." << std::endl; GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); return FAILED; } @@ -62,57 +58,24 @@ Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *mes return SUCCESS; } -Status Message2Operator::ParseBaseTypeField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, ge::Operator &ops) { - std::cout << "weewf" << std::endl; - switch (field->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { - int32_t value = reflection->GetInt32(*message, field); - GELOGD("Parse result(%s : %d)", field->name().c_str(), value); - (void)ops.SetAttr(field->name().c_str(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { - std::cout << "weCPPTYPE_UINT32ewf" << std::endl; - uint32_t value = reflection->GetUInt32(*message, field); - GELOGD("Parse result(%s : %u)", field->name().c_str(), value); - (void)ops.SetAttr(field->name().c_str(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { - std::cout << "CPPTYPE_INT64" << std::endl; - int64_t value = reflection->GetInt64(*message, field); - GELOGD("Parse result(%s : %ld)", field->name().c_str(), value); - (void)ops.SetAttr(field->name().c_str(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { - float value = reflection->GetFloat(*message, field); - GELOGD("Parse result(%s : %f)", field->name().c_str(), value); - (void)ops.SetAttr(field->name().c_str(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { - bool value = reflection->GetBool(*message, field); - GELOGD("Parse result(%s : %d)", field->name().c_str(), value); - (void)ops.SetAttr(field->name().c_str(), value); - break; - } - default: { - return FAILED; - } - } - return SUCCESS; -} - Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { GELOGD("Start to parse field: %s.", field->name().c_str()); - GE_CHK_BOOL_RET_STATUS(ParseBaseTypeField(reflection, message, field, ops) == FAILED, SUCCESS, - "Parse field: %s success.", field->name().c_str()); switch (field->cpp_type()) { +#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + valuetype value = reflection->Get##method(*message, field); \ + GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ + (void)ops.SetAttr(field->name().c_str(), value); \ + break; \ + } + CASE_FIELD_TYPE(INT32, Int32, int32_t, d); + CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); + CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); + CASE_FIELD_TYPE(FLOAT, Float, float, f); + CASE_FIELD_TYPE(BOOL, Bool, bool, d); +#undef CASE_FIELD_TYPE case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); int value = reflection->GetEnum(*message, field)->number(); @@ -145,74 +108,6 @@ Status Message2Operator::ParseField(const google::protobuf::Reflection *reflecti return SUCCESS; } -Status Message2Operator::ParseRepeatedBaseTypeField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, - ge::Operator &ops, const int field_size) { - switch (field->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { - std::vector attr_value; - for (int i = 0; i < field_size; i++) { - int32_t value = reflection->GetRepeatedInt32(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { - std::vector attr_value; - std::cout << "sdfsfCPPTYPE_UINT32"; - for (int i = 0; i < field_size; i++) { - uint32_t value = reflection->GetRepeatedUInt32(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { - std::vector attr_value; - for (int i = 0; i < field_size; i++) { - int64_t value = reflection->GetRepeatedInt64(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { - std::cout << "sdfsfCPPTYPE_FLOAT" << std::endl; - std::vector attr_value; - for (int i = 0; i < field_size; i++) { - float value = reflection->GetRepeatedFloat(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { - std::vector attr_value; - for (int i = 0; i < field_size; i++) { - bool value = reflection->GetRepeatedBool(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - std::vector attr_value; - for (int i = 0; i < field_size; i++) { - string value = reflection->GetRepeatedString(*message, field, i); - attr_value.push_back(value); - } - (void)ops.SetAttr(field->name().c_str(), attr_value); - break; - } - default: { - return FAILED; - } - } - return SUCCESS; -} - Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, const google::protobuf::FieldDescriptor *field, @@ -224,9 +119,24 @@ Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection * GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); return FAILED; } - GE_CHK_BOOL_RET_STATUS(ParseRepeatedBaseTypeField(reflection, message, field, ops, field_size) == FAILED, SUCCESS, - "Parse repeated field: %s success.", field->name().c_str()); switch (field->cpp_type()) { +#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + std::vector attr_value; \ + for (int i = 0; i < field_size; i++) { \ + valuetype value = reflection->GetRepeated##method(*message, field, i); \ + attr_value.push_back(value); \ + } \ + (void)ops.SetAttr(field->name().c_str(), attr_value); \ + break; \ + } + CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); + CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); + CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); + CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); + CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); + CASE_FIELD_TYPE_REPEATED(STRING, String, string); +#undef CASE_FIELD_TYPE_REPEATED case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { nlohmann::json message_json; Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), message_json[field->name()], diff --git a/parser/common/convert/message2operator.h b/parser/common/convert/message2operator.h index cbf885d..b247112 100644 --- a/parser/common/convert/message2operator.h +++ b/parser/common/convert/message2operator.h @@ -33,15 +33,6 @@ class Message2Operator { static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, const google::protobuf::FieldDescriptor *field, ge::Operator &ops); - - static Status ParseBaseTypeField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, ge::Operator &ops); - - static Status ParseRepeatedBaseTypeField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, - ge::Operator &ops, const int field_size); }; } // namespace ge #endif // PARSER_MESSAGE2OPERATOR_H diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 8223c9f..00678d8 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -3622,68 +3622,6 @@ TEST_F(UtestTensorflowParser, tensorflow_Message2Operator_ParseOperatorAttrs_tes EXPECT_EQ(ret, SUCCESS); } -TEST_F(UtestTensorflowParser, tensorflow_Message2Operator_ParseOperatorAttrs_success) -{ - Message2Operator mess2Op; - tensorflow::NodeDef nodedef; - nodedef.set_name("const"); - nodedef.set_op("const"); - - ge::OpDescPtr op_desc = std::make_shared(); - ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - tensorflow::AttrValue value; - value.set_s("string"); - TensorFlowUtil::AddNodeAttr("str", value, &nodedef); - value.clear_value(); - value.set_i(1); - TensorFlowUtil::AddNodeAttr("num", value, &nodedef); - value.clear_value(); - domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; - value.set_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("float", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_UINT32; - value.set_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("uint32", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_INT64; - value.set_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("int64", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_BOOL; - value.set_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("bool", value, &nodedef); - Status ret = mess2Op.ParseOperatorAttrs(&nodedef, 1, ops); - EXPECT_EQ(ret, SUCCESS); - - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_STRING; - value.mutable_list()->add_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("str_list", value, &nodedef); - value.clear_value(); - value.mutable_list()->add_i(1); - TensorFlowUtil::AddNodeAttr("num_list", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT; - value.mutable_list()->add_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("float_list", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_UINT32; - value.mutable_list()->add_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("uint32_list", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_INT64; - value.mutable_list()->add_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("int64_list", value, &nodedef); - value.clear_value(); - VALUE_TYPE = domi::tensorflow::DataType::DT_BOOL; - value.mutable_list()->add_type(VALUE_TYPE); - TensorFlowUtil::AddNodeAttr("bool_list", value, &nodedef); - ret = mess2Op.ParseOperatorAttrs(&nodedef, 1, ops); - EXPECT_EQ(ret, SUCCESS); -} - TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_RepeatedEnum2Json_test) { Pb2Json toJson;