Browse Source

!284 register subgraph auto mapping func to metadef

Merge pull request !284 from 陈华/master
pull/284/MERGE
i-robot Gitee 4 years ago
parent
commit
f12a415964
7 changed files with 333 additions and 0 deletions
  1. +2
    -0
      parser/common/CMakeLists.txt
  2. +150
    -0
      parser/common/auto_mapping_subgraph_io_index_func.cc
  3. +30
    -0
      parser/common/auto_mapping_subgraph_io_index_func.h
  4. +1
    -0
      tests/ut/parser/CMakeLists.txt
  5. BIN
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.onnx
  6. +73
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py
  7. +77
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 2
- 0
parser/common/CMakeLists.txt View File

@@ -25,6 +25,7 @@ set(SRC_LIST
"parser_fp16_t.cc" "parser_fp16_t.cc"
"thread_pool.cc" "thread_pool.cc"
"parser_utils.cc" "parser_utils.cc"
"auto_mapping_subgraph_io_index_func.cc"
) )


############ libparser_common.so ############ ############ libparser_common.so ############
@@ -45,6 +46,7 @@ target_include_directories(parser_common PRIVATE
${CMAKE_CURRENT_LIST_DIR} ${CMAKE_CURRENT_LIST_DIR}
${PARSER_DIR} ${PARSER_DIR}
${PARSER_DIR}/parser ${PARSER_DIR}/parser
${METADEF_DIR}
${METADEF_DIR}/inc ${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/register


+ 150
- 0
parser/common/auto_mapping_subgraph_io_index_func.cc View File

@@ -0,0 +1,150 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "auto_mapping_subgraph_io_index_func.h"
#include <vector>
#include "external/register/register.h"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "register/register_fmk_types.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace {
std::vector<NodePtr> FindNodesByType(const ge::ComputeGraphPtr &graph, const std::string &type) {
std::vector<NodePtr> nodes;
for (const auto &node : graph->GetDirectNode()) {
if (node == nullptr) {
continue;
}
std::string node_type = NodeUtils::GetNodeType(node);
GELOGI("Find node %s, node type is %s.", node->GetName().c_str(), node_type.c_str());
if (node_type == type) {
nodes.push_back(node);
continue;
}
}
return nodes;
}

Status AutoMappingSubgraphIndexByOutputNodesInfo(const ge::ComputeGraphPtr &compute_graph,
const std::function<Status(int netoutput_index, int &parent_output_index)> &output) {
const auto &out_nodes_info = compute_graph->GetGraphOutNodesInfo();
for (size_t i = 0; i < out_nodes_info.size(); ++i) {
const auto &out_node = out_nodes_info[i].first;
int32_t output_index = out_nodes_info[i].second;
int64_t index = static_cast<int64_t>(i);
int parent_index = -1;
auto ret = output(index, parent_index);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get parent output index %ld failed, node:%s", index, out_node->GetName().c_str());
GELOGE(FAILED, "[Get][ParentOutputIndex] Get parent output index %ld failed, node:%s",
index, out_node->GetName().c_str());
return FAILED;
}
auto op_desc = out_node->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(FAILED, "[Get][OpDesc] Op desc is null!");
return FAILED;
}
auto output_desc = op_desc->MutableOutputDesc(output_index);
if (output_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "Can not find output tensor desc from node:%s, index %d",
out_node->GetName().c_str(), output_index);
GELOGE(FAILED, "[Get][OutputDesc] Can not find output tensor desc from node:%s, index %d",
out_node->GetName().c_str(), output_index);
return FAILED;
}
if (!ge::AttrUtils::SetInt(output_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
REPORT_INNER_ERROR("E19999", "Set attr:%s of op:%s failed, parent_index:%d",
ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index);
GELOGE(FAILED, "[Set][Attr] Set attr:%s of op:%s failed, parent_index:%d",
ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), out_node->GetName().c_str(), parent_index);
return FAILED;
}
GELOGI("Generate subgraph output map for subgraph %s, out node index %ld, parent node index %d, node name:%s",
compute_graph->GetName().c_str(), index, parent_index, out_node->GetName().c_str());
}

return SUCCESS;
}

Status AutoMappingSubgraphIndexByDataNode(const ge::ComputeGraphPtr &compute_graph,
const std::function<Status(int data_index, int &parent_input_index)> &input) {
auto nodes = FindNodesByType(compute_graph, "Data");
for (size_t i = 0; i < nodes.size(); ++i) {
int parent_index = -1;
int index = -1;
if (!ge::AttrUtils::GetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_INDEX, index)) {
REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str());
GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", nodes[i]->GetName().c_str());
return FAILED;
}
GELOGI("Get index %d from data[%zu], node:%s", index, i, nodes[i]->GetName().c_str());
auto ret = input(index, parent_index);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get data index failed, op_name:%s", nodes[i]->GetName().c_str());
GELOGE(FAILED, "[Get][ParentInputIndex] Get data index failed, op_name:%s", nodes[i]->GetName().c_str());
return FAILED;
}
if (!ge::AttrUtils::SetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
REPORT_INNER_ERROR("E19999", "Set attr:%s failed, op_name:%s, ",
ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str());
GELOGE(FAILED, "[Set][Attr] Set attr:%s failed, op_name:%s, ",
ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), nodes[i]->GetName().c_str());
return FAILED;
}
GELOGI("Generate subgraph input map for subgraph %s, data index %zu, parent node index %d",
compute_graph->GetName().c_str(), i, parent_index);
}
return SUCCESS;
}
}

Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
const ge::Graph &graph,
const std::function<Status(int data_index, int &parent_input_index)> &input,
const std::function<Status(int netoutput_index, int &parent_output_index)> &output) {
GE_CHECK_NOTNULL(input);
GE_CHECK_NOTNULL(output);
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);

auto ret = AutoMappingSubgraphIndexByDataNode(compute_graph, input);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s input index failed,", graph.GetName().c_str());
GELOGE(ret, "[Mapping][InputIndex] Auto mapping graph:%s input index failed,", graph.GetName().c_str());
return ret;
}
ret = AutoMappingSubgraphIndexByOutputNodesInfo(compute_graph, output);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Auto mapping graph:%s output index failed,", graph.GetName().c_str());
GELOGE(ret, "[Mapping][OutputIndex] Auto mapping graph:%s output index failed,", graph.GetName().c_str());
return ret;
}

return SUCCESS;
}
} // namespace ge

namespace domi {
REGISTER_AUTOMAPPING_SUBGRAPH_IO_INDEX_FUNC(ONNX, ge::AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo);
} // namespace domi

+ 30
- 0
parser/common/auto_mapping_subgraph_io_index_func.h View File

@@ -0,0 +1,30 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_
#define PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_

#include <functional>
#include "external/graph/graph.h"
#include "external/register/register_error_codes.h"

namespace ge {
domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
const ge::Graph &graph,
const std::function<domi::Status(int data_index, int &parent_input_index)> &input,
const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output);
} // namespace ge
#endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_

+ 1
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -249,6 +249,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/register_tbe.cc" "${PARSER_DIR}/parser/common/register_tbe.cc"
"${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc"
"${PARSER_DIR}/parser/common/thread_pool.cc" "${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"


BIN
tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.onnx View File


+ 73
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py View File

@@ -0,0 +1,73 @@
# Given a bool scalar input cond.
# return constant tensor x if cond is True, otherwise return constant tensor y.
import numpy as np
import onnx
from onnx import helper
from onnx import numpy_helper
from onnx import AttributeProto, TensorProto, GraphProto

then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [5])
else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [5])
then_in = onnx.helper.make_tensor_value_info('then_in', onnx.TensorProto.FLOAT, [5])
else_in = onnx.helper.make_tensor_value_info('else_in', onnx.TensorProto.FLOAT, [5])
cond = onnx.helper.make_tensor_value_info('cond', onnx.TensorProto.FLOAT, [])
res = onnx.helper.make_tensor_value_info('res', onnx.TensorProto.FLOAT, [5])

x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
y = np.array([5, 4, 3, 2, 1]).astype(np.float32)

add_out_node = onnx.helper.make_node(
'Add',
inputs=['then_in', 'else_in'],
outputs=['add_out'],
)

then_identity_node = onnx.helper.make_node(
'Identity',
inputs=['add_out'],
outputs=['then_out'],
)

else_identity_node = onnx.helper.make_node(
'Identity',
inputs=['add_out'],
outputs=['else_out'],
)

then_body = onnx.helper.make_graph(
[then_identity_node],
'then_body',
[],
[then_out]
)

else_body = onnx.helper.make_graph(
[else_identity_node],
'else_body',
[],
[else_out]
)

if_node = onnx.helper.make_node(
'If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body
)

add_if_node = onnx.helper.make_node(
'Add',
inputs=['add_out', 'res'],
outputs=['res'],
)

graph_def = helper.make_graph(
[add_out_node, if_node, add_if_node],
'test_if',
[cond, else_in, then_in],
[res],
)
model_def = helper.make_model(graph_def, producer_name='if-onnx')
model_def.opset_import[0].version = 11
onnx.save(model_def, "./if.onnx")

+ 77
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -39,12 +39,51 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator&
return SUCCESS; return SUCCESS;
} }


Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) {
domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func =
domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX);
if (auto_mapping_subgraph_index_func == nullptr) {
std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl;
return FAILED;
}
return auto_mapping_subgraph_index_func(graph,
[&](int data_index, int &parent_index) -> Status {
parent_index = data_index + 1;
return SUCCESS;
},
[&](int output_index, int &parent_index) -> Status {
parent_index = output_index;
return SUCCESS;
});
}

void UtestOnnxParser::RegisterCustomOp() { void UtestOnnxParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Conv2D") REGISTER_CUSTOM_OP("Conv2D")
.FrameworkType(domi::ONNX) .FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Conv") .OriginOpType("ai.onnx::11::Conv")
.ParseParamsFn(ParseParams); .ParseParamsFn(ParseParams);


// register if op info to GE
REGISTER_CUSTOM_OP("If")
.FrameworkType(domi::ONNX)
.OriginOpType({"ai.onnx::9::If",
"ai.onnx::10::If",
"ai.onnx::11::If",
"ai.onnx::12::If",
"ai.onnx::13::If"})
.ParseParamsFn(ParseParams)
.ParseSubgraphPostFn(ParseSubgraphPostFnIf);

REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Add")
.ParseParamsFn(ParseParams);

REGISTER_CUSTOM_OP("Identity")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Identity")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data); OpRegistrationTbe::Instance()->Finalize(reg_data);
@@ -79,6 +118,33 @@ REG_OP(Conv2D)
.ATTR(data_format, String, "NHWC") .ATTR(data_format, String, "NHWC")
.ATTR(offset_x, Int, 0) .ATTR(offset_x, Int, 0)
.OP_END_FACTORY_REG(Conv2D) .OP_END_FACTORY_REG(Conv2D)

REG_OP(If)
.INPUT(cond, TensorType::ALL())
.DYNAMIC_INPUT(input, TensorType::ALL())
.DYNAMIC_OUTPUT(output, TensorType::ALL())
.GRAPH(then_branch)
.GRAPH(else_branch)
.OP_END_FACTORY_REG(If)

REG_OP(Add)
.INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OP_END_FACTORY_REG(Add)

REG_OP(Identity)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(Identity)
} }


TEST_F(UtestOnnxParser, onnx_parser_success) { TEST_F(UtestOnnxParser, onnx_parser_success) {
@@ -93,5 +159,16 @@ TEST_F(UtestOnnxParser, onnx_parser_success) {
EXPECT_EQ(ret, domi::SUCCESS); EXPECT_EQ(ret, domi::SUCCESS);
} }


TEST_F(UtestOnnxParser, onnx_parser_if_node) {
RegisterCustomOp();

std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/if.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, domi::SUCCESS);
}


} // namespace ge } // namespace ge

Loading…
Cancel
Save