diff --git a/parser/common/convert/message2operator.cc b/parser/common/convert/message2operator.cc index 9d7ab64..2f08049 100644 --- a/parser/common/convert/message2operator.cc +++ b/parser/common/convert/message2operator.cc @@ -1,169 +1,169 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "message2operator.h" - -#include - -#include "common/convert/pb2json.h" -#include "common/util.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -namespace { -const int kMaxParseDepth = 5; -const uint32_t kInteval = 2; -} // namespace - -Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { - if (depth > kMaxParseDepth) { - REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); - GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); - return FAILED; - } - - const google::protobuf::Reflection *reflection = message->GetReflection(); - GE_CHECK_NOTNULL(reflection); - std::vector field_desc; - reflection->ListFields(*message, &field_desc); - - for (auto &field : field_desc) { - GE_CHECK_NOTNULL(field); - if (field->is_repeated()) { - if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); - return FAILED; - } - } else { - if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); - return FAILED; - } - } - } - return SUCCESS; -} - -Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - valuetype value = reflection->Get##method(*message, field); \ - GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ - (void)ops.SetAttr(field->name(), value); \ - break; \ - } - CASE_FIELD_TYPE(INT32, Int32, int32_t, d); - CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); - CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); - CASE_FIELD_TYPE(FLOAT, Float, float, f); - CASE_FIELD_TYPE(BOOL, Bool, bool, d); -#undef CASE_FIELD_TYPE - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); - int value = reflection->GetEnum(*message, field)->number(); - GELOGD("Parse result(%s : %d)", field->name().c_str(), value); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - string value = reflection->GetString(*message, field); - GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); - if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); - return FAILED; - } - break; - } - default: { - ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, - {field->name().c_str(), "Unsupported field type"}); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse field: %s success.", field->name().c_str()); - return SUCCESS; -} - -Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, - ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - int field_size = reflection->FieldSize(*message, field); - if (field_size <= 0) { - REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); - GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); - return FAILED; - } - - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - std::vector attr_value; \ - for (int i = 0; i < field_size; i++) { \ - valuetype value = reflection->GetRepeated##method(*message, field, i); \ - attr_value.push_back(value); \ - } \ - (void)ops.SetAttr(field->name(), attr_value); \ - break; \ - } - CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); - CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); - CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); - CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); - CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); - CASE_FIELD_TYPE_REPEATED(STRING, String, string); -#undef CASE_FIELD_TYPE_REPEATED - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - nlohmann::json message_json; - Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), message_json[field->name()], - false); - std::string repeated_message_str; - try { - repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); - } catch (std::exception &e) { - ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); - return FAILED; - } catch (...) { - ErrorManager::GetInstance().ATCReportErrMessage("E19008"); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); - return FAILED; - } - (void)ops.SetAttr(field->name(), repeated_message_str); - break; - } - default: { - ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, - {field->name().c_str(), "Unsupported field type"}); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse repeated field: %s success.", field->name().c_str()); - return SUCCESS; -} +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "message2operator.h" + +#include + +#include "common/convert/pb2json.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace { +const int kMaxParseDepth = 5; +const uint32_t kInteval = 2; +} // namespace + +Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { + if (depth > kMaxParseDepth) { + REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); + GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); + return FAILED; + } + + const google::protobuf::Reflection *reflection = message->GetReflection(); + GE_CHECK_NOTNULL(reflection); + std::vector field_desc; + reflection->ListFields(*message, &field_desc); + + for (auto &field : field_desc) { + GE_CHECK_NOTNULL(field); + if (field->is_repeated()) { + if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); + return FAILED; + } + } else { + if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + +Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + valuetype value = reflection->Get##method(*message, field); \ + GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ + (void)ops.SetAttr(field->name(), value); \ + break; \ + } + CASE_FIELD_TYPE(INT32, Int32, int32_t, d); + CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); + CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); + CASE_FIELD_TYPE(FLOAT, Float, float, f); + CASE_FIELD_TYPE(BOOL, Bool, bool, d); +#undef CASE_FIELD_TYPE + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); + int value = reflection->GetEnum(*message, field)->number(); + GELOGD("Parse result(%s : %d)", field->name().c_str(), value); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetString(*message, field); + GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); + return FAILED; + } + break; + } + default: { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "Unsupported field type"}); + GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse field: %s success.", field->name().c_str()); + return SUCCESS; +} + +Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, + ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + int field_size = reflection->FieldSize(*message, field); + if (field_size <= 0) { + REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); + GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); + return FAILED; + } + + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + std::vector attr_value; \ + for (int i = 0; i < field_size; i++) { \ + valuetype value = reflection->GetRepeated##method(*message, field, i); \ + attr_value.push_back(value); \ + } \ + (void)ops.SetAttr(field->name(), attr_value); \ + break; \ + } + CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); + CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); + CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); + CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); + CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); + CASE_FIELD_TYPE_REPEATED(STRING, String, string); +#undef CASE_FIELD_TYPE_REPEATED + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + nlohmann::json message_json; + Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), message_json[field->name()], + false); + std::string repeated_message_str; + try { + repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); + } catch (std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); + GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); + return FAILED; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E19008"); + GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); + return FAILED; + } + (void)ops.SetAttr(field->name(), repeated_message_str); + break; + } + default: { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "Unsupported field type"}); + GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse repeated field: %s success.", field->name().c_str()); + return SUCCESS; +} } // namespace ge \ No newline at end of file diff --git a/parser/common/convert/message2operator.h b/parser/common/convert/message2operator.h index 0a7b8af..f33d4f3 100644 --- a/parser/common/convert/message2operator.h +++ b/parser/common/convert/message2operator.h @@ -1,38 +1,38 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARSER_MESSAGE2OPERATOR_H -#define PARSER_MESSAGE2OPERATOR_H - -#include "external/ge/ge_api_error_codes.h" -#include "external/graph/operator.h" -#include "google/protobuf/message.h" - -namespace ge { -class Message2Operator { - public: - static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); - - private: - static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); - - static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); -}; -} // namespace ge -#endif // PARSER_MESSAGE2OPERATOR_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_MESSAGE2OPERATOR_H +#define PARSER_MESSAGE2OPERATOR_H + +#include "external/ge/ge_api_error_codes.h" +#include "external/graph/operator.h" +#include "google/protobuf/message.h" + +namespace ge { +class Message2Operator { + public: + static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); + + private: + static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); + + static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); +}; +} // namespace ge +#endif // PARSER_MESSAGE2OPERATOR_H diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 9f16487..ebc9885 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -306,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" + "testcase/onnx_parser_testcase/message2operator_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" ) diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc new file mode 100644 index 0000000..39e1480 --- /dev/null +++ b/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/convert/message2operator.h" + +#include + +#include "proto/onnx/ge_onnx.pb.h" + +namespace ge { +class UtestMessage2Operator : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +TEST_F(UtestMessage2Operator, message_to_operator_success) { + ge::onnx::NodeProto input_node; + ge::onnx::AttributeProto *attribute = input_node.add_attribute(); + attribute->set_name("attribute"); + attribute->set_type(onnx::AttributeProto::AttributeType(1)); + attribute->set_f(1.0); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + attribute_tensor->set_data_type(1); + attribute_tensor->add_dims(4); + ge::Operator op_src("add", "Add"); + auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestMessage2Operator, message_to_operator_fail) { + ge::onnx::NodeProto input_node; + ge::onnx::AttributeProto *attribute = input_node.add_attribute(); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + attribute_tensor->add_double_data(1.00); + + ge::Operator op_src("add", "Add"); + auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src); + EXPECT_EQ(ret, FAILED); + + ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); + EXPECT_EQ(ret, FAILED); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index ed63849..678b8a6 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -40,10 +40,6 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& } static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { - string node_info; - if(op_src.GetAttr("attribute", node_info)==ge::GRAPH_SUCCESS) { - //std::cout<