| Author | SHA1 | Message | Date |
|---|---|---|---|
|
|
6c69b97e86 |
!339 fixed coverity warning
Merge pull request !339 from 李磊/r1.5.0 |
4 years ago |
|
|
9462b9675d |
!342 update submodule
Merge pull request !342 from HW_KK/r1.5.0 |
4 years ago |
|
|
14d1a77ddf | update submodule | 4 years ago |
|
|
596b630a4d | fixed coverity warning | 4 years ago |
|
|
b59a36e241 |
!338 fix coverity
Merge pull request !338 from wangjiming/r1.5.0 |
4 years ago |
|
|
b79ef8ad19 |
!330 update owners
Merge pull request !330 from 王涛/r1.5.0 |
4 years ago |
|
|
aec3c227ca |
!328 custom op register
Merge pull request !328 from wangjiming/r1.5.0 |
4 years ago |
|
|
c074dfa596 |
!329 update protobuf to 3.13.0
Merge pull request !329 from 李磊/r1.5.0 |
4 years ago |
|
|
57505df4ab | update version of protobuf to v3.13.0 | 4 years ago |
|
|
59ac22dfe4 | update owners | 4 years ago |
|
|
978cf3a0df |
!325 update submodule file
Merge pull request !325 from 王涛/r1.5.0 |
4 years ago |
|
|
4c82774e0c | update .gitmodules. | 4 years ago |
|
|
29a321b404 | fix | 4 years ago |
|
|
6c9441a473 | fix | 4 years ago |
|
|
6e3fa785dc | custom op register | 4 years ago |
| @@ -1,4 +1,4 @@ | |||||
| [submodule "metadef"] | [submodule "metadef"] | ||||
| path = metadef | path = metadef | ||||
| url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
| branch = master | |||||
| branch = r1.5.0 | |||||
| @@ -1,7 +1,11 @@ | |||||
| approvers: | approvers: | ||||
| - ji_chen | - ji_chen | ||||
| - wqtshg | |||||
| - ljl0711 | |||||
| - startzgf168 | |||||
| - lbisdaddy | |||||
| - liyihan123 | |||||
| reviewers: | reviewers: | ||||
| - xchu42 | - xchu42 | ||||
| - sheng-nan | - sheng-nan | ||||
| - wqtshg | |||||
| - wangxiaotian22 | |||||
| - zhangxiaokun9 | |||||
| @@ -15,7 +15,7 @@ else() | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") | set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") | ||||
| set(MD5 "f4489cb88922ad9c58cbe3308d59cee5") | set(MD5 "f4489cb88922ad9c58cbe3308d59cee5") | ||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") | |||||
| set(MD5 "1a6274bc4a65b55a6fa70e264d796490") | set(MD5 "1a6274bc4a65b55a6fa70e264d796490") | ||||
| endif () | endif () | ||||
| endif() | endif() | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit c6030152c6dc05515115765babb5d64fde649df4 | |||||
| Subproject commit 3ace5b6f10e0af784a1c3211fd769d6e8860e864 | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "common/convert/message2operator.h" | |||||
| #include "parser/common/convert/pb2json.h" | #include "parser/common/convert/pb2json.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
| #include "common/op_map.h" | #include "common/op_map.h" | ||||
| @@ -202,11 +203,9 @@ const int32_t kAnchorIndexTwo = 2; | |||||
| const int32_t kAnchorIndexThree = 3; | const int32_t kAnchorIndexThree = 3; | ||||
| const int32_t kNumOne = 1; | const int32_t kNumOne = 1; | ||||
| const size_t kTensorNum = 2; | const size_t kTensorNum = 2; | ||||
| const int kMaxParseDepth = 5; | |||||
| const int32_t kMinLineWorldSize = 3; | const int32_t kMinLineWorldSize = 3; | ||||
| const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 | const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 | ||||
| const int32_t kBase = 10; | const int32_t kBase = 10; | ||||
| const uint32_t kInteval = 2; | |||||
| const char *const kPython = "Python"; | const char *const kPython = "Python"; | ||||
| const char *const kProposalLayer = "ProposalLayer"; | const char *const kProposalLayer = "ProposalLayer"; | ||||
| const char *const kDetectionOutput = "DetectionOutput"; | const char *const kDetectionOutput = "DetectionOutput"; | ||||
| @@ -578,7 +577,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) { | |||||
| if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str()); | GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -589,146 +588,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { | |||||
| if (depth > kMaxParseDepth) { | |||||
| REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||||
| GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||||
| return FAILED; | |||||
| } | |||||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
| GE_CHECK_NOTNULL(reflection); | |||||
| vector<const google::protobuf::FieldDescriptor *> field_desc; | |||||
| reflection->ListFields(*message, &field_desc); | |||||
| for (auto &field : field_desc) { | |||||
| GE_CHECK_NOTNULL(field); | |||||
| if (field->is_repeated()) { | |||||
| if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, | |||||
| int depth, ge::Operator &ops) { | |||||
| GELOGD("Start to parse field: %s.", field->name().c_str()); | |||||
| switch (field->cpp_type()) { | |||||
| #define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||||
| valuetype value = reflection->Get##method(*message, field); \ | |||||
| GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ | |||||
| (void)ops.SetAttr(field->name(), value); \ | |||||
| break; \ | |||||
| } | |||||
| CASE_FIELD_TYPE(INT32, Int32, int32_t, d); | |||||
| CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); | |||||
| CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); | |||||
| CASE_FIELD_TYPE(FLOAT, Float, float, f); | |||||
| CASE_FIELD_TYPE(BOOL, Bool, bool, d); | |||||
| #undef CASE_FIELD_TYPE | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { | |||||
| GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); | |||||
| int value = reflection->GetEnum(*message, field)->number(); | |||||
| GELOGD("Parse result(%s : %d)", field->name().c_str(), value); | |||||
| (void)ops.SetAttr(field->name(), value); | |||||
| break; | |||||
| } | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { | |||||
| string value = reflection->GetString(*message, field); | |||||
| GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); | |||||
| (void)ops.SetAttr(field->name(), value); | |||||
| break; | |||||
| } | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||||
| const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||||
| if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||||
| std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||||
| GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("Parse field: %s success.", field->name().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, | |||||
| ge::Operator &ops) { | |||||
| GELOGD("Start to parse field: %s.", field->name().c_str()); | |||||
| int field_size = reflection->FieldSize(*message, field); | |||||
| if (field_size <= 0) { | |||||
| REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); | |||||
| GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| switch (field->cpp_type()) { | |||||
| #define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||||
| vector<valuetype> attr_value; \ | |||||
| for (int i = 0; i < field_size; i++) { \ | |||||
| valuetype value = reflection->GetRepeated##method(*message, field, i); \ | |||||
| attr_value.push_back(value); \ | |||||
| } \ | |||||
| (void)ops.SetAttr(field->name(), attr_value); \ | |||||
| break; \ | |||||
| } | |||||
| CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); | |||||
| CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); | |||||
| CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); | |||||
| CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); | |||||
| CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); | |||||
| CASE_FIELD_TYPE_REPEATED(STRING, String, string); | |||||
| #undef CASE_FIELD_TYPE_REPEATED | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||||
| nlohmann::json message_json; | |||||
| Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), | |||||
| message_json[field->name()], false); | |||||
| std::string repeated_message_str; | |||||
| try { | |||||
| repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||||
| } catch (std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| return FAILED; | |||||
| } catch (...) { | |||||
| REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); | |||||
| GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); | |||||
| return FAILED; | |||||
| } | |||||
| (void)ops.SetAttr(field->name(), repeated_message_str); | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||||
| std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||||
| GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("Parse repeated field: %s success.", field->name().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { | void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { | ||||
| auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); | auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); | ||||
| if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { | if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { | ||||
| @@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
| */ | */ | ||||
| Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, | Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, | ||||
| int index, std::vector<ge::Operator> &operators); | int index, std::vector<ge::Operator> &operators); | ||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse message and set operator attrs | |||||
| * @param [in] message, message of model | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info | |||||
| * @return SUCCESS parse message successfully | |||||
| * @return FAILED parse message failed | |||||
| */ | |||||
| Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse field and set operator attrs | |||||
| * @param [in] reflection, reflection of message | |||||
| * @param [in] message, message of model | |||||
| * @param [in] field, field of message | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info | |||||
| * @return SUCCESS parse field successfully | |||||
| * @return FAILED parse field failed | |||||
| */ | |||||
| Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse repeated field and set operator attrs | |||||
| * @param [in] reflection, reflection of message | |||||
| * @param [in] message, message of model | |||||
| * @param [in] field, field of message | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info by vector | |||||
| * @return SUCCESS parse field successfully | |||||
| * @return FAILED parse field failed | |||||
| */ | |||||
| Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Add blob information to the bottom_blobs_map and top_blobs_map_ | * @brief Add blob information to the bottom_blobs_map and top_blobs_map_ | ||||
| @@ -15,6 +15,7 @@ set(SRC_LIST | |||||
| "../tensorflow/tensorflow_fusion_op_parser.cc" | "../tensorflow/tensorflow_fusion_op_parser.cc" | ||||
| "../tensorflow/tensorflow_util.cc" | "../tensorflow/tensorflow_util.cc" | ||||
| "convert/pb2json.cc" | "convert/pb2json.cc" | ||||
| "convert/message2operator.cc" | |||||
| "op_def/ir_pb_converter.cc" | "op_def/ir_pb_converter.cc" | ||||
| "op_def/defs.cc" | "op_def/defs.cc" | ||||
| "op_def/op_schema.cc" | "op_def/op_schema.cc" | ||||
| @@ -0,0 +1,170 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "message2operator.h" | |||||
| #include <vector> | |||||
| #include "common/convert/pb2json.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int kMaxParseDepth = 5; | |||||
| const uint32_t kInteval = 2; | |||||
| } // namespace | |||||
| Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { | |||||
| GE_CHECK_NOTNULL(message); | |||||
| if (depth > kMaxParseDepth) { | |||||
| REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||||
| GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||||
| return FAILED; | |||||
| } | |||||
| const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
| GE_CHECK_NOTNULL(reflection); | |||||
| std::vector<const google::protobuf::FieldDescriptor *> field_desc; | |||||
| reflection->ListFields(*message, &field_desc); | |||||
| for (auto &field : field_desc) { | |||||
| GE_CHECK_NOTNULL(field); | |||||
| if (field->is_repeated()) { | |||||
| if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { | |||||
| GELOGD("Start to parse field: %s.", field->name().c_str()); | |||||
| switch (field->cpp_type()) { | |||||
| #define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||||
| valuetype value = reflection->Get##method(*message, field); \ | |||||
| GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ | |||||
| (void)ops.SetAttr(field->name(), value); \ | |||||
| break; \ | |||||
| } | |||||
| CASE_FIELD_TYPE(INT32, Int32, int32_t, d); | |||||
| CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); | |||||
| CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); | |||||
| CASE_FIELD_TYPE(FLOAT, Float, float, f); | |||||
| CASE_FIELD_TYPE(BOOL, Bool, bool, d); | |||||
| #undef CASE_FIELD_TYPE | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { | |||||
| GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); | |||||
| int value = reflection->GetEnum(*message, field)->number(); | |||||
| GELOGD("Parse result(%s : %d)", field->name().c_str(), value); | |||||
| (void)ops.SetAttr(field->name(), value); | |||||
| break; | |||||
| } | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { | |||||
| string value = reflection->GetString(*message, field); | |||||
| GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); | |||||
| (void)ops.SetAttr(field->name(), value); | |||||
| break; | |||||
| } | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||||
| const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||||
| if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||||
| std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||||
| GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("Parse field: %s success.", field->name().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, | |||||
| ge::Operator &ops) { | |||||
| GELOGD("Start to parse field: %s.", field->name().c_str()); | |||||
| int field_size = reflection->FieldSize(*message, field); | |||||
| if (field_size <= 0) { | |||||
| REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); | |||||
| GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| switch (field->cpp_type()) { | |||||
| #define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||||
| std::vector<valuetype> attr_value; \ | |||||
| for (int i = 0; i < field_size; i++) { \ | |||||
| valuetype value = reflection->GetRepeated##method(*message, field, i); \ | |||||
| attr_value.push_back(value); \ | |||||
| } \ | |||||
| (void)ops.SetAttr(field->name(), attr_value); \ | |||||
| break; \ | |||||
| } | |||||
| CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); | |||||
| CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); | |||||
| CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); | |||||
| CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); | |||||
| CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); | |||||
| CASE_FIELD_TYPE_REPEATED(STRING, String, string); | |||||
| #undef CASE_FIELD_TYPE_REPEATED | |||||
| case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||||
| nlohmann::json message_json; | |||||
| Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()], | |||||
| false); | |||||
| std::string repeated_message_str; | |||||
| try { | |||||
| repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||||
| } catch (std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| return FAILED; | |||||
| } catch (...) { | |||||
| REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); | |||||
| GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); | |||||
| return FAILED; | |||||
| } | |||||
| (void)ops.SetAttr(field->name(), repeated_message_str); | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||||
| std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||||
| GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("Parse repeated field: %s success.", field->name().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_MESSAGE2OPERATOR_H | |||||
| #define PARSER_MESSAGE2OPERATOR_H | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "external/graph/operator.h" | |||||
| #include "google/protobuf/message.h" | |||||
| namespace ge { | |||||
| class Message2Operator { | |||||
| public: | |||||
| static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); | |||||
| private: | |||||
| static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_MESSAGE2OPERATOR_H | |||||
| @@ -1,4 +1,4 @@ | |||||
| #!/usr/bin/python3 | |||||
| #!/usr/bin/env python3 | |||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| #------------------------------------------------------------------- | #------------------------------------------------------------------- | ||||
| # Purpose: | # Purpose: | ||||
| @@ -15,23 +15,25 @@ | |||||
| */ | */ | ||||
| #include "parser/onnx/onnx_custom_parser_adapter.h" | #include "parser/onnx/onnx_custom_parser_adapter.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| using domi::ParseParamFunc; | |||||
| using domi::ONNX; | using domi::ONNX; | ||||
| using domi::ParseParamByOpFunc; | |||||
| using domi::ParseParamFunc; | |||||
| namespace ge{ | |||||
| namespace ge { | |||||
| Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | ||||
| GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
| GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | ||||
| ParseParamFunc | |||||
| custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); | |||||
| ParseParamFunc custom_op_parser = | |||||
| domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); | |||||
| GE_CHECK_NOTNULL(custom_op_parser); | GE_CHECK_NOTNULL(custom_op_parser); | ||||
| if (custom_op_parser(op_src, op_dest) != SUCCESS) { | if (custom_op_parser(op_src, op_dest) != SUCCESS) { | ||||
| GELOGE(FAILED, "[Invoke][Custom_Op_Parser] Custom parser params failed."); | GELOGE(FAILED, "[Invoke][Custom_Op_Parser] Custom parser params failed."); | ||||
| @@ -40,5 +42,18 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) { | |||||
| ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); | |||||
| GE_CHECK_NOTNULL(custom_op_parser); | |||||
| if (custom_op_parser(op_src, op_dest) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Invoke][Custom_Op_Parser] failed, node name:%s, type:%s", op_src.GetName().c_str(), | |||||
| op_src.GetOpType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); | REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser { | |||||
| /// @return SUCCESS parse successfully | /// @return SUCCESS parse successfully | ||||
| /// @return FAILED parse failed | /// @return FAILED parse failed | ||||
| Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; | Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; | ||||
| Status ParseParams(const Operator &op_src, Operator &op_dest); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -42,7 +42,7 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser { | |||||
| std::vector<int64_t> user_input_dims_v_; | std::vector<int64_t> user_input_dims_v_; | ||||
| bool is_subgraph_data_op_; | |||||
| bool is_subgraph_data_op_ = false; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <queue> | #include <queue> | ||||
| #include "common/convert/message2operator.h" | |||||
| #include "common/convert/pb2json.h" | #include "common/convert/pb2json.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| @@ -36,6 +37,7 @@ | |||||
| #include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
| #include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
| #include "parser/common/prototype_pass_manager.h" | #include "parser/common/prototype_pass_manager.h" | ||||
| #include "parser/onnx/onnx_custom_parser_adapter.h" | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
| @@ -555,6 +557,40 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, | |||||
| std::shared_ptr<OpParser> &op_parser) { | |||||
| GE_CHECK_NOTNULL(node_proto); | |||||
| GE_CHECK_NOTNULL(op_parser); | |||||
| std::string op_type = node_proto->op_type(); | |||||
| Status status = FAILED; | |||||
| domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); | |||||
| if (parse_param_func == nullptr) { | |||||
| status = op_parser->ParseParams(node_proto, op); | |||||
| } else { | |||||
| ge::Operator op_src(node_proto->name(), op_type); | |||||
| status = Message2Operator::ParseOperatorAttrs(node_proto, 1, op_src); | |||||
| if (status != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Auto mapping node:%s(%s) to operator failed", | |||||
| node_proto->name().c_str(), op_type.c_str()); | |||||
| GELOGE(status, "Node[%s] auto mapping failed.", node_proto->name().c_str()); | |||||
| return status; | |||||
| } | |||||
| std::shared_ptr<ge::OnnxCustomParserAdapter> onnx_custom_op_parser = | |||||
| std::dynamic_pointer_cast<ge::OnnxCustomParserAdapter>(op_parser); | |||||
| status = onnx_custom_op_parser->ParseParams(op_src, op); | |||||
| op_src.BreakConnect(); | |||||
| } | |||||
| if (status != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type}); | |||||
| GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str()); | |||||
| return status; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | ||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
| ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ||||
| @@ -586,11 +622,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
| GE_CHECK_NOTNULL(factory); | GE_CHECK_NOTNULL(factory); | ||||
| std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type); | std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type); | ||||
| GE_CHECK_NOTNULL(op_parser); | GE_CHECK_NOTNULL(op_parser); | ||||
| std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser); | |||||
| GE_CHECK_NOTNULL(onnx_op_parser); | |||||
| status = onnx_op_parser->ParseParams(node_proto, op); | |||||
| status = ParseOpParam(node_proto, op, op_parser); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | |||||
| GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | ||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -598,7 +631,6 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
| GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), | GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), | ||||
| op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); | op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); | ||||
| ge::graphStatus graph_status = graph.AddOp(op); | ge::graphStatus graph_status = graph.AddOp(op); | ||||
| if (graph_status != ge::GRAPH_SUCCESS) { | if (graph_status != ge::GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); | GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); | ||||
| @@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
| void ClearMembers(); | void ClearMembers(); | ||||
| Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr<OpParser> &op_parser); | |||||
| Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | ||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | ||||
| @@ -55,7 +55,7 @@ public: | |||||
| */ | */ | ||||
| std::shared_ptr<SubgraphAdapter> CreateSubgraphAdapter(const std::string &op_type); | std::shared_ptr<SubgraphAdapter> CreateSubgraphAdapter(const std::string &op_type); | ||||
| ~SubgraphAdapterFactory() = default; | |||||
| protected: | protected: | ||||
| /** | /** | ||||
| * @brief SubgraphAdapter creation function | * @brief SubgraphAdapter creation function | ||||
| @@ -1457,6 +1457,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) | |||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; | AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; | ||||
| GE_CHECK_NOTNULL(funcDefBytes.GetData()); | |||||
| string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | ||||
| GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( | GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( | ||||
| funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); | funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); | ||||
| @@ -75,9 +75,9 @@ Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | |||||
| op->TfOpDef(attr_v.s()); | op->TfOpDef(attr_v.s()); | ||||
| } else { | } else { | ||||
| GE_CHK_BOOL_EXEC(type == "_Retval", | GE_CHK_BOOL_EXEC(type == "_Retval", | ||||
| GE_DELETE_NEW_SINGLE(pkg_node); | |||||
| REPORT_INNER_ERROR("E19999", "In NodeDef:%s Attr:opdef is not exist, check invalid", | REPORT_INNER_ERROR("E19999", "In NodeDef:%s Attr:opdef is not exist, check invalid", | ||||
| pkg_node->name().c_str()); | pkg_node->name().c_str()); | ||||
| GE_DELETE_NEW_SINGLE(pkg_node); | |||||
| return PARAM_INVALID, "In NodeDef %s Attr opdef is not exist.", pkg_node->name().c_str()); | return PARAM_INVALID, "In NodeDef %s Attr opdef is not exist.", pkg_node->name().c_str()); | ||||
| } | } | ||||
| @@ -221,6 +221,7 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" | "${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" | ||||
| "${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" | "${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" | ||||
| "${PARSER_DIR}/parser/common/convert/pb2json.cc" | "${PARSER_DIR}/parser/common/convert/pb2json.cc" | ||||
| "${PARSER_DIR}/parser/common/convert/message2operator.cc" | |||||
| "${PARSER_DIR}/parser/common/data_op_parser.cc" | "${PARSER_DIR}/parser/common/data_op_parser.cc" | ||||
| "${PARSER_DIR}/parser/common/model_saver.cc" | "${PARSER_DIR}/parser/common/model_saver.cc" | ||||
| "${PARSER_DIR}/parser/common/op_def/arg_op.cc" | "${PARSER_DIR}/parser/common/op_def/arg_op.cc" | ||||
| @@ -305,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||||
| set(PARSER_UT_FILES | set(PARSER_UT_FILES | ||||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | ||||
| "testcase/onnx_parser_testcase/message2operator_unittest.cc" | |||||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | ||||
| ) | ) | ||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/convert/message2operator.h" | |||||
| #include <gtest/gtest.h> | |||||
| #include "proto/onnx/ge_onnx.pb.h" | |||||
| namespace ge { | |||||
| class UtestMessage2Operator : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(UtestMessage2Operator, message_to_operator_success) { | |||||
| ge::onnx::NodeProto input_node; | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
| attribute->set_name("attribute"); | |||||
| attribute->set_type(onnx::AttributeProto::AttributeType(1)); | |||||
| attribute->set_f(1.0); | |||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
| attribute_tensor->set_data_type(1); | |||||
| attribute_tensor->add_dims(4); | |||||
| ge::Operator op_src("add", "Add"); | |||||
| auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| TEST_F(UtestMessage2Operator, message_to_operator_fail) { | |||||
| ge::onnx::NodeProto input_node; | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
| attribute_tensor->add_double_data(1.00); | |||||
| ge::Operator op_src("add", "Add"); | |||||
| auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,3 +1,10 @@ | |||||
| #!/usr/bin/env python3 | |||||
| # -*- coding: UTF-8 -*- | |||||
| #------------------------------------------------------------------- | |||||
| # Purpose: | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd. All rights reserved. | |||||
| #------------------------------------------------------------------- | |||||
| # Given a bool scalar input cond. | # Given a bool scalar input cond. | ||||
| # return constant tensor x if cond is True, otherwise return constant tensor y. | # return constant tensor x if cond is True, otherwise return constant tensor y. | ||||
| import numpy as np | import numpy as np | ||||
| @@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | ||||
| domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | ||||
| domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | ||||
| @@ -72,6 +76,7 @@ void UtestOnnxParser::RegisterCustomOp() { | |||||
| "ai.onnx::12::If", | "ai.onnx::12::If", | ||||
| "ai.onnx::13::If"}) | "ai.onnx::13::If"}) | ||||
| .ParseParamsFn(ParseParams) | .ParseParamsFn(ParseParams) | ||||
| .ParseParamsByOperatorFn(ParseParamByOpFunc) | |||||
| .ParseSubgraphPostFn(ParseSubgraphPostFnIf); | .ParseSubgraphPostFn(ParseSubgraphPostFnIf); | ||||
| REGISTER_CUSTOM_OP("Add") | REGISTER_CUSTOM_OP("Add") | ||||