Merge pull request !284 from 陈华/masterpull/284/MERGE
| @@ -25,6 +25,7 @@ set(SRC_LIST | |||||
| "parser_fp16_t.cc" | "parser_fp16_t.cc" | ||||
| "thread_pool.cc" | "thread_pool.cc" | ||||
| "parser_utils.cc" | "parser_utils.cc" | ||||
| "auto_mapping_subgraph_io_index_func.cc" | |||||
| ) | ) | ||||
| ############ libparser_common.so ############ | ############ libparser_common.so ############ | ||||
| @@ -45,6 +46,7 @@ target_include_directories(parser_common PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | ${CMAKE_CURRENT_LIST_DIR} | ||||
| ${PARSER_DIR} | ${PARSER_DIR} | ||||
| ${PARSER_DIR}/parser | ${PARSER_DIR}/parser | ||||
| ${METADEF_DIR} | |||||
| ${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${METADEF_DIR}/inc/register | ${METADEF_DIR}/inc/register | ||||
| @@ -0,0 +1,150 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "auto_mapping_subgraph_io_index_func.h" | |||||
| #include <vector> | |||||
| #include "external/register/register.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/debug/ge_util.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "register/register_fmk_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| std::vector<NodePtr> FindNodesByType(const ge::ComputeGraphPtr &graph, const std::string &type) { | |||||
| std::vector<NodePtr> nodes; | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| if (node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string node_type = NodeUtils::GetNodeType(node); | |||||
| GELOGI("Find node %s, node type is %s.", node->GetName().c_str(), node_type.c_str()); | |||||
| if (node_type == type) { | |||||
| nodes.push_back(node); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return nodes; | |||||
| } | |||||
| Status AutoMappingSubgraphIndexByOutputNodesInfo(const ge::ComputeGraphPtr &compute_graph, | |||||
| const std::function<Status(int netoutput_index, int &parent_output_index)> &output) { | |||||
| const auto &out_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||||
| for (size_t i = 0; i < out_nodes_info.size(); ++i) { | |||||
| const auto &out_node = out_nodes_info[i].first; | |||||
| int32_t output_index = out_nodes_info[i].second; | |||||
| int64_t index = static_cast<int64_t>(i); | |||||
| int parent_index = -1; | |||||
| auto ret = output(index, parent_index); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Get parent output index %ld failed, node:%s", index, out_node->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Get][ParentOutputIndex] Get parent output index %ld failed, node:%s", | |||||
| index, out_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto op_desc = out_node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(FAILED, "[Get][OpDesc] Op desc is null!"); | |||||
| return FAILED; | |||||
| } | |||||
| auto output_desc = op_desc->MutableOutputDesc(output_index); | |||||
| if (output_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Can not find output tensor desc from node:%s, index %d", | |||||
| out_node->GetName().c_str(), output_index); | |||||
| GELOGE(FAILED, "[Get][OutputDesc] Can not find output tensor desc from node:%s, index %d", | |||||
| out_node->GetName().c_str(), output_index); | |||||
| return FAILED; | |||||
| } | |||||
| if (!ge::AttrUtils::SetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
| REPORT_INNER_ERROR("E19999", "Set attr:%s of op:%s failed, parent_index:%d", | |||||
| ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index); | |||||
| GELOGE(FAILED, "[Set][Attr] Set attr:%s of op:%s failed, parent_index:%d", | |||||
| ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Generate subgraph output map for subgraph %s, out node index %ld, parent node index %d, node name:%s", | |||||
| compute_graph->GetName().c_str(), index, parent_index, out_node->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status AutoMappingSubgraphIndexByDataNode(const ge::ComputeGraphPtr &compute_graph, | |||||
| const std::function<Status(int data_index, int &parent_input_index)> &input) { | |||||
| auto nodes = FindNodesByType(compute_graph, "Data"); | |||||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||||
| int parent_index = -1; | |||||
| int index = -1; | |||||
| if (!ge::AttrUtils::GetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_INDEX, index)) { | |||||
| REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Get index %d from data[%zu], node:%s", index, i, nodes[i]->GetName().c_str()); | |||||
| auto ret = input(index, parent_index); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Get data index failed, op_name:%s", nodes[i]->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Get][ParentInputIndex] Get data index failed, op_name:%s", nodes[i]->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!ge::AttrUtils::SetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
| REPORT_INNER_ERROR("E19999", "Set attr:%s failed, op_name:%s, ", | |||||
| ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Set][Attr] Set attr:%s failed, op_name:%s, ", | |||||
| ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Generate subgraph input map for subgraph %s, data index %zu, parent node index %d", | |||||
| compute_graph->GetName().c_str(), i, parent_index); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||||
| const ge::Graph &graph, | |||||
| const std::function<Status(int data_index, int &parent_input_index)> &input, | |||||
| const std::function<Status(int netoutput_index, int &parent_output_index)> &output) { | |||||
| GE_CHECK_NOTNULL(input); | |||||
| GE_CHECK_NOTNULL(output); | |||||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| auto ret = AutoMappingSubgraphIndexByDataNode(compute_graph, input); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s input index failed,", graph.GetName().c_str()); | |||||
| GELOGE(ret, "[Mapping][InputIndex] Auto mapping graph:%s input index failed,", graph.GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret = AutoMappingSubgraphIndexByOutputNodesInfo(compute_graph, output); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s output index failed,", graph.GetName().c_str()); | |||||
| GELOGE(ret, "[Mapping][OutputIndex] Auto mapping graph:%s output index failed,", graph.GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| namespace domi { | |||||
| REGISTER_AUTOMAPPING_SUBGRAPH_IO_INDEX_FUNC(ONNX, ge::AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo); | |||||
| } // namespace domi | |||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ | |||||
| #define PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ | |||||
| #include <functional> | |||||
| #include "external/graph/graph.h" | |||||
| #include "external/register/register_error_codes.h" | |||||
| namespace ge { | |||||
| domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||||
| const ge::Graph &graph, | |||||
| const std::function<domi::Status(int data_index, int &parent_input_index)> &input, | |||||
| const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output); | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ | |||||
| @@ -249,6 +249,7 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/register_tbe.cc" | "${PARSER_DIR}/parser/common/register_tbe.cc" | ||||
| "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | ||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| @@ -0,0 +1,73 @@ | |||||
| # Given a bool scalar input cond. | |||||
| # return constant tensor x if cond is True, otherwise return constant tensor y. | |||||
| import numpy as np | |||||
| import onnx | |||||
| from onnx import helper | |||||
| from onnx import numpy_helper | |||||
| from onnx import AttributeProto, TensorProto, GraphProto | |||||
| then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [5]) | |||||
| else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [5]) | |||||
| then_in = onnx.helper.make_tensor_value_info('then_in', onnx.TensorProto.FLOAT, [5]) | |||||
| else_in = onnx.helper.make_tensor_value_info('else_in', onnx.TensorProto.FLOAT, [5]) | |||||
| cond = onnx.helper.make_tensor_value_info('cond', onnx.TensorProto.FLOAT, []) | |||||
| res = onnx.helper.make_tensor_value_info('res', onnx.TensorProto.FLOAT, [5]) | |||||
| x = np.array([1, 2, 3, 4, 5]).astype(np.float32) | |||||
| y = np.array([5, 4, 3, 2, 1]).astype(np.float32) | |||||
| add_out_node = onnx.helper.make_node( | |||||
| 'Add', | |||||
| inputs=['then_in', 'else_in'], | |||||
| outputs=['add_out'], | |||||
| ) | |||||
| then_identity_node = onnx.helper.make_node( | |||||
| 'Identity', | |||||
| inputs=['add_out'], | |||||
| outputs=['then_out'], | |||||
| ) | |||||
| else_identity_node = onnx.helper.make_node( | |||||
| 'Identity', | |||||
| inputs=['add_out'], | |||||
| outputs=['else_out'], | |||||
| ) | |||||
| then_body = onnx.helper.make_graph( | |||||
| [then_identity_node], | |||||
| 'then_body', | |||||
| [], | |||||
| [then_out] | |||||
| ) | |||||
| else_body = onnx.helper.make_graph( | |||||
| [else_identity_node], | |||||
| 'else_body', | |||||
| [], | |||||
| [else_out] | |||||
| ) | |||||
| if_node = onnx.helper.make_node( | |||||
| 'If', | |||||
| inputs=['cond'], | |||||
| outputs=['res'], | |||||
| then_branch=then_body, | |||||
| else_branch=else_body | |||||
| ) | |||||
| add_if_node = onnx.helper.make_node( | |||||
| 'Add', | |||||
| inputs=['add_out', 'res'], | |||||
| outputs=['res'], | |||||
| ) | |||||
| graph_def = helper.make_graph( | |||||
| [add_out_node, if_node, add_if_node], | |||||
| 'test_if', | |||||
| [cond, else_in, then_in], | |||||
| [res], | |||||
| ) | |||||
| model_def = helper.make_model(graph_def, producer_name='if-onnx') | |||||
| model_def.opset_import[0].version = 11 | |||||
| onnx.save(model_def, "./if.onnx") | |||||
| @@ -39,12 +39,51 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | |||||
| domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | |||||
| domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | |||||
| if (auto_mapping_subgraph_index_func == nullptr) { | |||||
| std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl; | |||||
| return FAILED; | |||||
| } | |||||
| return auto_mapping_subgraph_index_func(graph, | |||||
| [&](int data_index, int &parent_index) -> Status { | |||||
| parent_index = data_index + 1; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [&](int output_index, int &parent_index) -> Status { | |||||
| parent_index = output_index; | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | |||||
| void UtestOnnxParser::RegisterCustomOp() { | void UtestOnnxParser::RegisterCustomOp() { | ||||
| REGISTER_CUSTOM_OP("Conv2D") | REGISTER_CUSTOM_OP("Conv2D") | ||||
| .FrameworkType(domi::ONNX) | .FrameworkType(domi::ONNX) | ||||
| .OriginOpType("ai.onnx::11::Conv") | .OriginOpType("ai.onnx::11::Conv") | ||||
| .ParseParamsFn(ParseParams); | .ParseParamsFn(ParseParams); | ||||
| // register if op info to GE | |||||
| REGISTER_CUSTOM_OP("If") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType({"ai.onnx::9::If", | |||||
| "ai.onnx::10::If", | |||||
| "ai.onnx::11::If", | |||||
| "ai.onnx::12::If", | |||||
| "ai.onnx::13::If"}) | |||||
| .ParseParamsFn(ParseParams) | |||||
| .ParseSubgraphPostFn(ParseSubgraphPostFnIf); | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| REGISTER_CUSTOM_OP("Identity") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Identity") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | OpRegistrationTbe::Instance()->Finalize(reg_data); | ||||
| @@ -79,6 +118,33 @@ REG_OP(Conv2D) | |||||
| .ATTR(data_format, String, "NHWC") | .ATTR(data_format, String, "NHWC") | ||||
| .ATTR(offset_x, Int, 0) | .ATTR(offset_x, Int, 0) | ||||
| .OP_END_FACTORY_REG(Conv2D) | .OP_END_FACTORY_REG(Conv2D) | ||||
| REG_OP(If) | |||||
| .INPUT(cond, TensorType::ALL()) | |||||
| .DYNAMIC_INPUT(input, TensorType::ALL()) | |||||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
| .GRAPH(then_branch) | |||||
| .GRAPH(else_branch) | |||||
| .OP_END_FACTORY_REG(If) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| REG_OP(Identity) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OP_END_FACTORY_REG(Identity) | |||||
| } | } | ||||
| TEST_F(UtestOnnxParser, onnx_parser_success) { | TEST_F(UtestOnnxParser, onnx_parser_success) { | ||||
| @@ -93,5 +159,16 @@ TEST_F(UtestOnnxParser, onnx_parser_success) { | |||||
| EXPECT_EQ(ret, domi::SUCCESS); | EXPECT_EQ(ret, domi::SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, onnx_parser_if_node) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/onnx_model/if.onnx"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
| EXPECT_EQ(ret, domi::SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||