diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index c15d455..6dc8b6d 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -25,6 +25,7 @@ set(SRC_LIST "parser_fp16_t.cc" "thread_pool.cc" "parser_utils.cc" + "auto_mapping_subgraph_io_index_func.cc" ) ############ libparser_common.so ############ @@ -45,6 +46,7 @@ target_include_directories(parser_common PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PARSER_DIR} ${PARSER_DIR}/parser + ${METADEF_DIR} ${METADEF_DIR}/inc ${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/register diff --git a/parser/common/auto_mapping_subgraph_io_index_func.cc b/parser/common/auto_mapping_subgraph_io_index_func.cc new file mode 100644 index 0000000..6926022 --- /dev/null +++ b/parser/common/auto_mapping_subgraph_io_index_func.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "auto_mapping_subgraph_io_index_func.h" +#include +#include "external/register/register.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_util.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "register/register_fmk_types.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace { +std::vector FindNodesByType(const ge::ComputeGraphPtr &graph, const std::string &type) { + std::vector nodes; + for (const auto &node : graph->GetDirectNode()) { + if (node == nullptr) { + continue; + } + std::string node_type = NodeUtils::GetNodeType(node); + GELOGI("Find node %s, node type is %s.", node->GetName().c_str(), node_type.c_str()); + if (node_type == type) { + nodes.push_back(node); + continue; + } + } + return nodes; +} + +Status AutoMappingSubgraphIndexByOutputNodesInfo(const ge::ComputeGraphPtr &compute_graph, + const std::function &output) { + const auto &out_nodes_info = compute_graph->GetGraphOutNodesInfo(); + for (size_t i = 0; i < out_nodes_info.size(); ++i) { + const auto &out_node = out_nodes_info[i].first; + int32_t output_index = out_nodes_info[i].second; + int64_t index = static_cast(i); + int parent_index = -1; + auto ret = output(index, parent_index); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get parent output index %ld failed, node:%s", index, out_node->GetName().c_str()); + GELOGE(FAILED, "[Get][ParentOutputIndex] Get parent output index %ld failed, node:%s", + index, out_node->GetName().c_str()); + return FAILED; + } + auto op_desc = out_node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "[Get][OpDesc] Op desc is null!"); + return FAILED; + } + auto output_desc = op_desc->MutableOutputDesc(output_index); + if (output_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "Can not find output tensor desc from node:%s, index %d", + out_node->GetName().c_str(), output_index); + GELOGE(FAILED, "[Get][OutputDesc] Can not find output tensor desc from node:%s, index %d", + out_node->GetName().c_str(), output_index); + return FAILED; + } + if (!ge::AttrUtils::SetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + REPORT_INNER_ERROR("E19999", "Set attr:%s of op:%s failed, parent_index:%d", + ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index); + GELOGE(FAILED, "[Set][Attr] Set attr:%s of op:%s failed, parent_index:%d", + ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index); + return FAILED; + } + GELOGI("Generate subgraph output map for subgraph %s, out node index %ld, parent node index %d, node name:%s", + compute_graph->GetName().c_str(), index, parent_index, out_node->GetName().c_str()); + } + + return SUCCESS; +} + +Status AutoMappingSubgraphIndexByDataNode(const ge::ComputeGraphPtr &compute_graph, + const std::function &input) { + auto nodes = FindNodesByType(compute_graph, "Data"); + for (size_t i = 0; i < nodes.size(); ++i) { + int parent_index = -1; + int index = -1; + if (!ge::AttrUtils::GetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_INDEX, index)) { + REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str()); + GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str()); + return FAILED; + } + GELOGI("Get index %d from data[%zu], node:%s", index, i, nodes[i]->GetName().c_str()); + auto ret = input(index, parent_index); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get data index failed, op_name:%s", nodes[i]->GetName().c_str()); + GELOGE(FAILED, "[Get][ParentInputIndex] Get data index failed, op_name:%s", nodes[i]->GetName().c_str()); + return FAILED; + } + if (!ge::AttrUtils::SetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + REPORT_INNER_ERROR("E19999", "Set attr:%s failed, op_name:%s, ", + ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str()); + GELOGE(FAILED, "[Set][Attr] Set attr:%s failed, op_name:%s, ", + ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str()); + return FAILED; + } + GELOGI("Generate subgraph input map for subgraph %s, data index %zu, parent node index %d", + compute_graph->GetName().c_str(), i, parent_index); + } + return SUCCESS; +} +} + +Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( + const ge::Graph &graph, + const std::function &input, + const std::function &output) { + GE_CHECK_NOTNULL(input); + GE_CHECK_NOTNULL(output); + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + auto ret = AutoMappingSubgraphIndexByDataNode(compute_graph, input); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s input index failed,", graph.GetName().c_str()); + GELOGE(ret, "[Mapping][InputIndex] Auto mapping graph:%s input index failed,", graph.GetName().c_str()); + return ret; + } + ret = AutoMappingSubgraphIndexByOutputNodesInfo(compute_graph, output); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s output index failed,", graph.GetName().c_str()); + GELOGE(ret, "[Mapping][OutputIndex] Auto mapping graph:%s output index failed,", graph.GetName().c_str()); + return ret; + } + + return SUCCESS; +} +} // namespace ge + +namespace domi { +REGISTER_AUTOMAPPING_SUBGRAPH_IO_INDEX_FUNC(ONNX, ge::AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo); +} // namespace domi \ No newline at end of file diff --git a/parser/common/auto_mapping_subgraph_io_index_func.h b/parser/common/auto_mapping_subgraph_io_index_func.h new file mode 100644 index 0000000..a19e12e --- /dev/null +++ b/parser/common/auto_mapping_subgraph_io_index_func.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ +#define PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ + +#include +#include "external/graph/graph.h" +#include "external/register/register_error_codes.h" + +namespace ge { +domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( + const ge::Graph &graph, + const std::function &input, + const std::function &output); +} // namespace ge +#endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 153998a..9f0578d 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -249,6 +249,7 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/common/register_tbe.cc" "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" "${PARSER_DIR}/parser/common/thread_pool.cc" + "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.onnx b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.onnx new file mode 100644 index 0000000..ff2230a Binary files /dev/null and b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.onnx differ diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py new file mode 100644 index 0000000..e9aaef7 --- /dev/null +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py @@ -0,0 +1,73 @@ +# Given a bool scalar input cond. +# return constant tensor x if cond is True, otherwise return constant tensor y. +import numpy as np +import onnx +from onnx import helper +from onnx import numpy_helper +from onnx import AttributeProto, TensorProto, GraphProto + +then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [5]) +else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [5]) +then_in = onnx.helper.make_tensor_value_info('then_in', onnx.TensorProto.FLOAT, [5]) +else_in = onnx.helper.make_tensor_value_info('else_in', onnx.TensorProto.FLOAT, [5]) +cond = onnx.helper.make_tensor_value_info('cond', onnx.TensorProto.FLOAT, []) +res = onnx.helper.make_tensor_value_info('res', onnx.TensorProto.FLOAT, [5]) + +x = np.array([1, 2, 3, 4, 5]).astype(np.float32) +y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + +add_out_node = onnx.helper.make_node( + 'Add', + inputs=['then_in', 'else_in'], + outputs=['add_out'], +) + +then_identity_node = onnx.helper.make_node( + 'Identity', + inputs=['add_out'], + outputs=['then_out'], +) + +else_identity_node = onnx.helper.make_node( + 'Identity', + inputs=['add_out'], + outputs=['else_out'], +) + +then_body = onnx.helper.make_graph( + [then_identity_node], + 'then_body', + [], + [then_out] +) + +else_body = onnx.helper.make_graph( + [else_identity_node], + 'else_body', + [], + [else_out] +) + +if_node = onnx.helper.make_node( + 'If', + inputs=['cond'], + outputs=['res'], + then_branch=then_body, + else_branch=else_body +) + +add_if_node = onnx.helper.make_node( + 'Add', + inputs=['add_out', 'res'], + outputs=['res'], +) + +graph_def = helper.make_graph( + [add_out_node, if_node, add_if_node], + 'test_if', + [cond, else_in, then_in], + [res], +) +model_def = helper.make_model(graph_def, producer_name='if-onnx') +model_def.opset_import[0].version = 11 +onnx.save(model_def, "./if.onnx") diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 532c42e..a3c54d5 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -39,12 +39,51 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& return SUCCESS; } +Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { + domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = + domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); + if (auto_mapping_subgraph_index_func == nullptr) { + std::cout<<"auto mapping if subgraph func is nullptr!"< Status { + parent_index = data_index + 1; + return SUCCESS; + }, + [&](int output_index, int &parent_index) -> Status { + parent_index = output_index; + return SUCCESS; + }); +} + void UtestOnnxParser::RegisterCustomOp() { REGISTER_CUSTOM_OP("Conv2D") .FrameworkType(domi::ONNX) .OriginOpType("ai.onnx::11::Conv") .ParseParamsFn(ParseParams); + // register if op info to GE + REGISTER_CUSTOM_OP("If") + .FrameworkType(domi::ONNX) + .OriginOpType({"ai.onnx::9::If", + "ai.onnx::10::If", + "ai.onnx::11::If", + "ai.onnx::12::If", + "ai.onnx::13::If"}) + .ParseParamsFn(ParseParams) + .ParseSubgraphPostFn(ParseSubgraphPostFnIf); + + REGISTER_CUSTOM_OP("Add") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Add") + .ParseParamsFn(ParseParams); + + REGISTER_CUSTOM_OP("Identity") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Identity") + .ParseParamsFn(ParseParams); + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; for (auto reg_data : reg_datas) { OpRegistrationTbe::Instance()->Finalize(reg_data); @@ -79,6 +118,33 @@ REG_OP(Conv2D) .ATTR(data_format, String, "NHWC") .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv2D) + +REG_OP(If) + .INPUT(cond, TensorType::ALL()) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(then_branch) + .GRAPH(else_branch) + .OP_END_FACTORY_REG(If) + +REG_OP(Add) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OP_END_FACTORY_REG(Add) + +REG_OP(Identity) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Identity) } TEST_F(UtestOnnxParser, onnx_parser_success) { @@ -93,5 +159,16 @@ TEST_F(UtestOnnxParser, onnx_parser_success) { EXPECT_EQ(ret, domi::SUCCESS); } +TEST_F(UtestOnnxParser, onnx_parser_if_node) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/onnx_model/if.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, domi::SUCCESS); +} } // namespace ge \ No newline at end of file