Browse Source

remove to new file

pull/321/head
wjm 4 years ago
parent
commit
90fd868979
5 changed files with 266 additions and 212 deletions
  1. +168
    -168
      parser/common/convert/message2operator.cc
  2. +38
    -38
      parser/common/convert/message2operator.h
  3. +1
    -0
      tests/ut/parser/CMakeLists.txt
  4. +58
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc
  5. +1
    -6
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 168
- 168
parser/common/convert/message2operator.cc View File

@@ -1,169 +1,169 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "message2operator.h"
#include <vector>
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
namespace ge {
namespace {
const int kMaxParseDepth = 5;
const uint32_t kInteval = 2;
} // namespace
Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return FAILED;
}
const google::protobuf::Reflection *reflection = message->GetReflection();
GE_CHECK_NOTNULL(reflection);
std::vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
}
}
return SUCCESS;
}
Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field);
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str());
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str());
return FAILED;
}
break;
}
default: {
ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"},
{field->name().c_str(), "Unsupported field type"});
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse field: %s success.", field->name().c_str());
return SUCCESS;
}
Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth,
ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
int field_size = reflection->FieldSize(*message, field);
if (field_size <= 0) {
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str());
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
std::vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()],
false);
std::string repeated_message_str;
try {
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()});
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E19008");
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string.");
return FAILED;
}
(void)ops.SetAttr(field->name(), repeated_message_str);
break;
}
default: {
ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"},
{field->name().c_str(), "Unsupported field type"});
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse repeated field: %s success.", field->name().c_str());
return SUCCESS;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "message2operator.h"
#include <vector>
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
namespace ge {
namespace {
const int kMaxParseDepth = 5;
const uint32_t kInteval = 2;
} // namespace
Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return FAILED;
}
const google::protobuf::Reflection *reflection = message->GetReflection();
GE_CHECK_NOTNULL(reflection);
std::vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
}
}
return SUCCESS;
}
Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field);
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str());
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str());
return FAILED;
}
break;
}
default: {
ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"},
{field->name().c_str(), "Unsupported field type"});
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse field: %s success.", field->name().c_str());
return SUCCESS;
}
Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth,
ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
int field_size = reflection->FieldSize(*message, field);
if (field_size <= 0) {
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str());
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
std::vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()],
false);
std::string repeated_message_str;
try {
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()});
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E19008");
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string.");
return FAILED;
}
(void)ops.SetAttr(field->name(), repeated_message_str);
break;
}
default: {
ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"},
{field->name().c_str(), "Unsupported field type"});
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse repeated field: %s success.", field->name().c_str());
return SUCCESS;
}
} // namespace ge

+ 38
- 38
parser/common/convert/message2operator.h View File

@@ -1,38 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PARSER_MESSAGE2OPERATOR_H
#define PARSER_MESSAGE2OPERATOR_H
#include "external/ge/ge_api_error_codes.h"
#include "external/graph/operator.h"
#include "google/protobuf/message.h"
namespace ge {
class Message2Operator {
public:
static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);
private:
static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
static Status ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
};
} // namespace ge
#endif // PARSER_MESSAGE2OPERATOR_H
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PARSER_MESSAGE2OPERATOR_H
#define PARSER_MESSAGE2OPERATOR_H
#include "external/ge/ge_api_error_codes.h"
#include "external/graph/operator.h"
#include "google/protobuf/message.h"
namespace ge {
class Message2Operator {
public:
static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);
private:
static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
static Status ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
};
} // namespace ge
#endif // PARSER_MESSAGE2OPERATOR_H

+ 1
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -306,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)

set(PARSER_UT_FILES
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
)



+ 58
- 0
tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/convert/message2operator.h"

#include <gtest/gtest.h>

#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
class UtestMessage2Operator : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

TEST_F(UtestMessage2Operator, message_to_operator_success) {
ge::onnx::NodeProto input_node;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("attribute");
attribute->set_type(onnx::AttributeProto::AttributeType(1));
attribute->set_f(1.0);
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
attribute_tensor->set_data_type(1);
attribute_tensor->add_dims(4);
ge::Operator op_src("add", "Add");
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestMessage2Operator, message_to_operator_fail) {
ge::onnx::NodeProto input_node;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
attribute_tensor->add_double_data(1.00);

ge::Operator op_src("add", "Add");
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src);
EXPECT_EQ(ret, FAILED);

ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src);
EXPECT_EQ(ret, FAILED);
}
} // namespace ge

+ 1
- 6
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -40,10 +40,6 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator&
}

static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
string node_info;
if(op_src.GetAttr("attribute", node_info)==ge::GRAPH_SUCCESS) {
//std::cout<<node_info<<std::endl;
}
return SUCCESS;
}

@@ -86,8 +82,7 @@ void UtestOnnxParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Add")
.ParseParamsFn(ParseParams)
.ParseParamsByOperatorFn(ParseParamByOpFunc);
.ParseParamsFn(ParseParams);

REGISTER_CUSTOM_OP("Identity")
.FrameworkType(domi::ONNX)


Loading…
Cancel
Save