Browse Source

onnx_if_loop

pull/266/head
陈华 4 years ago
parent
commit
b87af40416
14 changed files with 695 additions and 48 deletions
  1. +0
    -4
      .gitmodules
  2. +0
    -1
      metadef
  3. +3
    -0
      parser/common/parser_types.cc
  4. +3
    -1
      parser/onnx/CMakeLists.txt
  5. +13
    -1
      parser/onnx/onnx_data_parser.cc
  6. +6
    -0
      parser/onnx/onnx_data_parser.h
  7. +125
    -0
      parser/onnx/onnx_if_subgraph_adapter.cc
  8. +47
    -0
      parser/onnx/onnx_if_subgraph_adapter.h
  9. +343
    -36
      parser/onnx/onnx_parser.cc
  10. +19
    -5
      parser/onnx/onnx_parser.h
  11. +48
    -0
      parser/onnx/onnx_retval_parser.cc
  12. +29
    -0
      parser/onnx/onnx_retval_parser.h
  13. +56
    -0
      parser/onnx/onnx_subgraph_adapter.h
  14. +3
    -0
      parser/onnx/onnx_util.h

+ 0
- 4
.gitmodules View File

@@ -1,4 +0,0 @@
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = master

+ 0
- 1
metadef

@@ -1 +0,0 @@
Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e

+ 3
- 0
parser/common/parser_types.cc View File

@@ -438,6 +438,9 @@ const char *HVDCALLBACKALLGATHER = "HorovodAllgather";
const char *HVDCALLBACKBROADCAST = "HorovodBroadcast";
const char *HVDWAIT = "HorovodWait";

// onnx output operator
const char *_RETVAL = "_Retval";

///
/// @brief Magic number of model file
///


+ 3
- 1
parser/onnx/CMakeLists.txt View File

@@ -8,7 +8,9 @@ set(SRC_LIST
"onnx_parser.cc"
"onnx_data_parser.cc"
"onnx_util.cc"
"onnx_constant_parser.cc"
"onnx_constant_parser.cc"
"onnx_retval_parser.cc"
"onnx_if_subgraph_adapter.cc"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})


+ 13
- 1
parser/onnx/onnx_data_parser.cc View File

@@ -35,6 +35,11 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def)
GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str());
return FAILED;
}
// Subgraph data operator don't need parse input shape
// the shape mappings from parent node input
if (IsSubgraphOp()) {
return SUCCESS;
}

if (ParseInputFromUser(op_def) != SUCCESS) {
GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str());
@@ -72,15 +77,23 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &
// Get attr t:'input_tensor' form NodeProto
int64_t data_type = 1;
int64_t index = 0;
isSubgraphOp_ = false;
for (auto it : node->attribute()) {
if (it.name() == ge::kAttrNameInput) {
data_type = ParseInputTensor(it);
} else if (it.name() == ge::kAttrNameIndex) {
index = it.i();
GELOGI("The node has attribute with index: %ld", index);
} else if (it.name() == ge::kAttrNameIsSubgraphOp) {
isSubgraphOp_ = true;
}
}

op_def.SetAttr(ge::ATTR_NAME_INDEX, index);
if (IsSubgraphOp()) {
return SUCCESS;
}

// Trans onnx type to ge type
DataType type = OnnxUtil::ConvertOnnxDataType(data_type);
if (type == ge::DataType::DT_UNDEFINED) {
@@ -88,7 +101,6 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &
return FAILED;
}
op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast<int64_t>(type));
op_def.SetAttr(ge::ATTR_NAME_INDEX, index);

return SUCCESS;
}


+ 6
- 0
parser/onnx/onnx_data_parser.h View File

@@ -32,11 +32,17 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser {

Status ParseInputFromUser(const ge::Operator &op_def);

bool IsSubgraphOp() {
return isSubgraphOp_;
}

int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute);

std::vector<int64_t> model_input_dims_v_;

std::vector<int64_t> user_input_dims_v_;

bool isSubgraphOp_;
};
} // namespace ge



+ 125
- 0
parser/onnx/onnx_if_subgraph_adapter.cc View File

@@ -0,0 +1,125 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/onnx/onnx_if_subgraph_adapter.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "onnx_parser.h"

using domi::ONNX;

namespace ge{
namespace {
const std::map<string, int> kAttrNames = {{"then_branch", 0}, {"else_branch", 1}};
const int kIfNodeMaxAttrSize = 2;
}
Status OnnxIfSubgraphAdapter::AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_node,
std::deque<ge::onnx::GraphProto *> &onnx_graph_tasks,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
GE_CHECK_NOTNULL(parent_node);
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str());

auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graph_tasks, name_to_onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Parse if node failed.");
return ret;
}
return SUCCESS;
}

Status OnnxIfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node,
std::deque<ge::onnx::GraphProto *> &onnx_graph_tasks,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
if (parent_node->attribute_size() != kIfNodeMaxAttrSize) {
GELOGE(FAILED, "invalid graph, node->attribute_size():%d.", parent_node->attribute_size());
return FAILED;
}
GELOGI("node attribute size:%d.", parent_node->attribute_size());
std::set<std::string> all_inputs;
std::vector<ge::onnx::GraphProto*> onnx_graphs;
for (int i = 0; i < parent_node->attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i);
std::string attr_name = attribute->name();
auto itr = kAttrNames.find(attr_name);
if (itr == kAttrNames.end()) {
GELOGE(FAILED, "invalid attribute name:%s.", attr_name.c_str());
return FAILED;
}
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g();

std::string sub_graph_name = parent_node->name() + "_" + std::to_string(itr->second) + "_" + itr->first;
GELOGI("Start parse if attribute:%s, subgraph name:%s.", attr_name.c_str(), sub_graph_name.c_str());
name_to_onnx_graph[sub_graph_name] = onnx_graph;
onnx_graph_tasks.push_back(onnx_graph);

auto ret = GetSubgraphsAllInputs(parent_node, onnx_graph, all_inputs);
if (ret != SUCCESS) {
GELOGE(ret, "get subgraps all inputs failed, attr_name:%s.", attr_name.c_str());
return ret;
}
onnx_graphs.emplace_back(onnx_graph);
}
for (auto onnx_graph : onnx_graphs) {
AddInputNodeForGraph(all_inputs, onnx_graph);
}
AddInputForParentNode(all_inputs, parent_node);
return SUCCESS;
}

Status OnnxIfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::NodeProto *parent_node, ge::onnx::GraphProto *sub_graph,
std::set<std::string> &all_inputs) {
std::set<std::string> graph_inputs;
std::set<std::string> graph_outputs;
for (int i = 0; i < sub_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = sub_graph->mutable_node(i);
std::string node_name = node_proto->name();

for (int j = 0; j < node_proto->input_size(); j++) {
graph_inputs.emplace(node_proto->input(j));
}

for (int j = 0; j < node_proto->output_size(); j++) {
graph_outputs.emplace(node_proto->output(j));
}
}

for (auto input : graph_inputs) {
auto out_iter = graph_outputs.find(input);
if (out_iter == graph_outputs.end()) {
// Record input node needed to be construct
all_inputs.emplace(input);
}
}

return SUCCESS;
}

void OnnxIfSubgraphAdapter::AddInputNodeForGraph(const std::set<std::string> &all_inputs,
ge::onnx::GraphProto *onnx_graph) {
for (auto input_name : all_inputs) {
ge::onnx::ValueInfoProto *value_info = onnx_graph->add_input();
value_info->set_name(input_name);
}
}

void OnnxIfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_inputs,
ge::onnx::NodeProto *parent_node) {
for (auto input_name : all_inputs) {
parent_node->add_input(input_name);
}
}
} // namespace ge

+ 47
- 0
parser/onnx/onnx_if_subgraph_adapter.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_
#define GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_

#include <set>
#include <string>
#include "onnx_subgraph_adapter.h"
#include "proto/onnx/ge_onnx.pb.h"

using ge::onnx::NodeProto;

namespace ge {
class PARSER_FUNC_VISIBILITY OnnxIfSubgraphAdapter : public OnnxSubgraphAdapter {
public:
/// @brief parse params
/// @param [in] parent_op parent op
/// @return SUCCESS parse success
/// @return FAILED Parse failed
Status AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_op,
std::deque<ge::onnx::GraphProto *> &onnx_graph_tasks,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override;
private:
Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::deque<ge::onnx::GraphProto *> &onnx_graph_tasks,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);
Status GetSubgraphsAllInputs(ge::onnx::NodeProto *parent_node, ge::onnx::GraphProto *sub_graph,
std::set<std::string> &all_inputs);
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto *onnx_graph);
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto *parent_node);
};
} // namespace ge

#endif // GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_

+ 343
- 36
parser/onnx/onnx_parser.cc View File

@@ -37,6 +37,10 @@
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"
#include "register/register_fmk_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "onnx_if_subgraph_adapter.h"

namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
@@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file,
}

GE_CHECK_NOTNULL(model_parser);
// parse caffe model_file to GE graph
// parse onnx model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
@@ -144,25 +148,187 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
namespace ge {
namespace {
const std::map<std::string, std::string> kOnnxOpMap = {
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kOpTypeInput, ge::parser::DATA},
{ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kOpTypeOutput, ge::parser::_RETVAL},
};
const char* const MATMULV2 = "MatMulV2";

const std::vector<std::string> kMakeOperatorNotByIr = {ge::parser::_RETVAL};
const char *const MATMULV2 = "MatMulV2";
const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2};
const int64_t kDimValue = 1;

struct ParseArg {
ge::onnx::GraphProto *onnx_graph;
ge::NodePtr parent_node;
std::string graph_name;
uint32_t subgraph_index;
};

Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) {
GELOGI("Gen subgraph parse tasks start");
for (auto &node : parent_graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) {
auto i = subgraph_name_to_index.second;
auto subgraph_iname = subgraph_name_to_index.first;
if (subgraph_iname.empty()) {
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str());
continue;
}

// A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE
auto unique_name = node->GetName() + "_" + std::to_string(i) + "_" + subgraph_iname;

GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s",
node->GetName().c_str(), i, unique_name.c_str());
args.push_back({nullptr, node, unique_name, i});
}
}
GELOGI("Gen subgraph parse tasks end");
return SUCCESS;
}

Status BuildLinkForSonAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) {
if (arg.parent_node == nullptr) {
return SUCCESS;
}
auto parent_node = arg.parent_node;
auto index = arg.subgraph_index;
auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(),
parent_node->GetName().c_str(), index);
return ret;
}
return SUCCESS;
}

Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) {
if (arg.parent_node == nullptr) {
return SUCCESS;
}
std::string op_type = arg.parent_node->GetType();
std::string op_name = arg.parent_node->GetName();
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
auto post_func =
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc( op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) {
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
return SUCCESS;
}
}

GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(),
arg.parent_node->GetType().c_str());

// refresh node_name in subgraph
for (const ge::NodePtr &node : sub_graph->GetDirectNode()) {
if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
continue;
}
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName());
}

auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph);
Status ret = FAILED;
if (post_func != nullptr) {
ret = post_func(arg.graph_name, graph);
} else if (parse_func_v2 != nullptr) {
ret = parse_func_v2(arg.graph_name.c_str(), graph);
}
if (ret != SUCCESS) {
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s",
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(),
arg.parent_node->GetType().c_str());
return FAILED;
}
return SUCCESS;
}
}

Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto *onnx_graph) {
if (onnx_graph->output_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E16001");
GELOGE(FAILED, "Onnx graph has zero output");
return FAILED;
}

// get output value info map
int64_t index = 0;
for (int i = 0; i < onnx_graph->output_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph->output(i);
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str());

ge::onnx::TensorProto tensor_tmp;
if (value_info.has_type()) {
const ge::onnx::TypeProto type = value_info.type();
if (type.has_tensor_type()) {
const ge::onnx::TypeProto_Tensor type_proto_tensor = type.tensor_type();
int32_t elem_type = type_proto_tensor.elem_type();
tensor_tmp.set_data_type(elem_type);
if (type_proto_tensor.has_shape()) {
const ge::onnx::TensorShapeProto tensor_shape = type_proto_tensor.shape();
for (int j = 0; j < tensor_shape.dim_size(); j++) {
const ge::onnx::TensorShapeProto_Dimension dimension = tensor_shape.dim(j);
int64_t dim_value = dimension.dim_value();
tensor_tmp.add_dims(dim_value);
GELOGI("elem_type: %d, dim_value: %ld", elem_type, dim_value);
}
}
}
}
// Construct node for output
ge::onnx::NodeProto *output_node = onnx_graph->add_node();
output_node->set_name(value_info.name());
output_node->set_op_type(ge::kOpTypeOutput);
output_node->add_input(value_info.name());
// add tensor
ge::onnx::AttributeProto *attribute = output_node->add_attribute();
attribute->set_name(ge::kAttrNameOutput);
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_tmp;
// add index
ge::onnx::AttributeProto *attribute_index = output_node->add_attribute();
attribute_index->set_name(ge::kAttrNameIndex);
attribute_index->set_i(index++);
output_node_names_.emplace_back(value_info.name());
}
return SUCCESS;
}

Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
Status OnnxModelParser::ParseInput(bool isSubgraph, ge::onnx::GraphProto *onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) {
if (onnx_graph.input_size() == 0) {
if (onnx_graph->input_size() == 0) {
if (!isSubgraph) {
ErrorManager::GetInstance().ATCReportErrMessage("E16001");
GELOGE(FAILED, "Root onnx graph has zero input");
return FAILED;
} else {
// sometime subgraph does not have input, we using constant as input nodes
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph->mutable_node(i);
if (node->op_type() == "Constant") {
input_node_names_.emplace_back(node->name());
}
}
}
return SUCCESS;
}

if (onnx_graph->input_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E16001");
GELOGE(FAILED, "Onnx graph has zero input");
GELOGE(FAILED, "onnx graph has zero input");
return FAILED;
}

// get input value info map
int64_t data_index = 0;
for (int i = 0; i < onnx_graph.input_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph.input(i);
for (int i = 0; i < onnx_graph->input_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph->input(i);
GELOGI("The index of %d input name : %s.", i, value_info.name().c_str());

/// if the input is initialized by a default value found in ‘initializer’,
@@ -194,7 +360,7 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
}
}
// Construct node for input
ge::onnx::NodeProto *input_node = onnx_graph.add_node();
ge::onnx::NodeProto *input_node = onnx_graph->add_node();
input_node->set_name(value_info.name());
input_node->set_op_type(ge::kOpTypeInput);
input_node->add_output(value_info.name());
@@ -207,17 +373,22 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
ge::onnx::AttributeProto *attribute_index = input_node->add_attribute();
attribute_index->set_name(ge::kAttrNameIndex);
attribute_index->set_i(data_index++);
// add subgraph attr
if (isSubgraph) {
attribute = input_node->add_attribute();
attribute->set_name(ge::kAttrNameIsSubgraphOp);
}
input_node_names_.emplace_back(value_info.name());
}
return SUCCESS;
}

Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto *onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) {
// Construct const node for weight
int index = 0;
for (auto it : initializer_name_tensor) {
ge::onnx::NodeProto *const_node = onnx_graph.add_node();
ge::onnx::NodeProto *const_node = onnx_graph->add_node();
std::string output_name = it.first + "_" + to_string(index++);
const_node->set_name(output_name);
const_node->set_op_type(ge::kOpTypeConstant);
@@ -231,10 +402,10 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
return SUCCESS;
}

Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) {
Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto *onnx_graph) {
int index = 0;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph->mutable_node(i);
if (node->name().empty()) {
std::string node_name = node->op_type() + "_" + to_string(index++);
node->set_name(node_name);
@@ -318,9 +489,14 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot
string node_name = node_proto->name();
op = ge::OperatorFactory::CreateOperator(node_name, op_type);
if (op.GetName() != node_name) {
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type});
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
return INTERNAL_ERROR;
GELOGW("IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) {
op = ge::Operator(node_name, op_type);
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type});
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
return INTERNAL_ERROR;
}
}

GELOGI("After create operator, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(),
@@ -381,18 +557,20 @@ Status OnnxModelParser::SetOperatorInputs() {
GE_CHECK_NOTNULL(dst_op_desc);
auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op);
GE_CHECK_NOTNULL(src_op_desc);
dst_op.SetInput(dst_op_desc->GetInputNameByIndex(dst_index), src_op,
src_op_desc->GetOutputNameByIndex(src_index));
std::string dst_name = dst_op_desc->GetInputNameByIndex(dst_index);
std::string src_name = src_op_desc->GetOutputNameByIndex(src_index);
GELOGI("dst_name:%s, src_name:%s.", dst_name.c_str(), src_name.c_str());
dst_op.SetInput(dst_name, src_op, src_name);
}
}
}
return SUCCESS;
}

Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) {
Status OnnxModelParser::Prechecker(ge::onnx::GraphProto *onnx_graph) {
ge::PreChecker::Instance().Clear();
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph->mutable_node(i);
std::string ori_type;
Status ret = ConstructOriType(node, ori_type);
if (ret != SUCCESS) {
@@ -418,9 +596,9 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) {
return SUCCESS;
}

Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto *onnx_graph, ge::Graph &graph) {
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
std::string node_name = node_proto->name();
std::string ori_type = node_proto->op_type();
GELOGI("Start parse node which name is %s, type is %s", node_name.c_str(), ori_type.c_str());
@@ -455,6 +633,10 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
return status;
}

GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(),
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize());


ge::graphStatus graph_status = graph.AddOp(op);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str());
@@ -488,6 +670,21 @@ Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) {
return SUCCESS;
}

Status OnnxModelParser::GetGraphOutputs(std::vector<ge::Operator> &output_ops) {
for (auto out_name : output_node_names_) {
auto out_op = name_operator_.find(out_name);
if (out_op == name_operator_.end()) {
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.",
out_name.c_str());
return PARAM_INVALID;
}
output_ops.emplace_back(out_op->second);
GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str());
}

return SUCCESS;
}

Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) {
GE_CHECK_NOTNULL(file);
GELOGI("File path is %s.", file);
@@ -515,13 +712,56 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) {
void OnnxModelParser::ClearMembers() {
name_operator_.clear();
input_node_names_.clear();
output_node_names_.clear();
inputs_map_.clear();
outputs_map_.clear();
}

Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::deque<ge::onnx::GraphProto *> onnx_graph_tasks;
int index = 0;
onnx_graph_tasks.push_back(&root_onnx_graph);

while (!onnx_graph_tasks.empty()) {
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front();
onnx_graph_tasks.pop_front();
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
if (node_proto->name().empty()) {
std::string node_name = node_proto->op_type() + "_" + to_string(index++);
node_proto->set_name(node_name);
}
if (node_proto->op_type() == "If") {
OnnxIfSubgraphAdapter adapter;
if (adapter.AdaptAndFindAllOnnxSubgraphs(node_proto, onnx_graph_tasks, name_to_onnx_graph) != SUCCESS) {
GELOGE(FAILED, "adapt subgraph failed.");
return FAILED;
}
}
}
}
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) {
if (!onnx_model.has_graph()) {
ErrorManager::GetInstance().ATCReportErrMessage("E16004");
GELOGE(PARAM_INVALID, "Onnx model do not has graph.");
return FAILED;
}
ge::onnx::GraphProto onnx_graph = onnx_model.graph();
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_graph;
std::deque<ParseArg> tasks;
ge::onnx::GraphProto root_onnx_graph = onnx_model.graph();

auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "adapt subgraph failed.");
return FAILED;
}

auto opset_import = onnx_model.opset_import();
for (auto it : opset_import) {
@@ -529,10 +769,72 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version());
}

tasks.push_back({&root_onnx_graph, nullptr, "", 0});

while (!tasks.empty()) {
ParseArg arg = tasks.front();
tasks.pop_front();
bool isSubgraph = (arg.parent_node != nullptr) ? true : false;

if (arg.onnx_graph == nullptr) {
auto itr = name_to_onnx_graph.find(arg.graph_name);
if (itr == name_to_onnx_graph.end()) {
GELOGE(FAILED, "can not find onnx graph, graph:%s.", arg.graph_name.c_str());
return FAILED;
}
arg.onnx_graph = itr->second;
}

ge::onnx::GraphProto *onnx_graph = arg.onnx_graph;
ge::Graph tmp_graph(arg.graph_name);
if (isSubgraph) {
ret = ModelParseToGraphImpl(isSubgraph, onnx_graph, tmp_graph);
} else {
ret = ModelParseToGraphImpl(isSubgraph, onnx_graph, root_graph);
}
if (ret != SUCCESS) {
GELOGE(ret, "model parse to graph impl failed.");
return ret;
}

ge::ComputeGraphPtr cur_compute_graph;
if (isSubgraph) {
cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph);
} else {
cur_compute_graph = ge::GraphUtils::GetComputeGraph(root_graph);
}

ret = PostOpProcessForSubgraph(arg, cur_compute_graph);
if (ret != SUCCESS) {
GELOGE(ret, "PostOpProcessForSubgraph failed.");
return ret;
}

ret = BuildLinkForSonAndParentGraph(cur_compute_graph, arg);
if (ret != SUCCESS) {
GELOGE(ret, "BuildLinkForSonAndParentGraph failed.");
return ret;
}

ret = GenSubgraphParseTasks(cur_compute_graph, tasks);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", cur_compute_graph->GetName().c_str());
return ret;
}

}
UpdateDataFormat(root_graph);
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraphImpl(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, ge::Graph &graph) {

ClearMembers();

// 2. Get all inializer.
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor;
for (int i = 0; i < onnx_graph.initializer_size(); i++) {
ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i);
for (int i = 0; i < onnx_graph->initializer_size(); i++) {
ge::onnx::TensorProto initializer_tensor = onnx_graph->initializer(i);
if (!initializer_tensor.name().empty()) {
initializer_name_tensor[initializer_tensor.name()] = initializer_tensor;
GELOGI("Initializer name: %s .", initializer_tensor.name().c_str());
@@ -541,7 +843,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model

// 3. Parse Input from graph.
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size());
Status ret = ParseInput(onnx_graph, initializer_name_tensor);

Status ret = ParseInput(isSubgraph, onnx_graph, initializer_name_tensor);
if (ret != SUCCESS) {
GELOGE(ret, "Parse input for onnx failed.");
return ret;
@@ -555,6 +858,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

ret = ParseOutput(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Parse output for onnx failed.");
return ret;
}

// 5. Update node name for node do not has name.
ret = UpdateAllNodeName(onnx_graph);
if (ret != SUCCESS) {
@@ -582,6 +891,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

std::vector<string> op_names;
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 8. Set all operator input.
ret = SetOperatorInputs();
if (ret != SUCCESS) {
@@ -589,13 +902,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

std::vector<string> op_names;
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 9. Construct graph.
std::vector<ge::Operator> input_ops;

ret = GetGraphInputs(input_ops);
if (ret != SUCCESS) {
GELOGE(ret, "Get graph inputs failed.");
@@ -604,7 +912,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
graph.SetInputs(input_ops);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));
UpdateDataFormat(graph);

GELOGI("Onnx model parser success.");
return SUCCESS;


+ 19
- 5
parser/onnx/onnx_parser.h View File

@@ -34,6 +34,7 @@
#include <map>
#include <string>
#include <vector>
#include <deque>
#include "external/register/register_error_codes.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h"
@@ -70,15 +71,17 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
}

private:
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);
Status ParseAllNodeProto(ge::onnx::GraphProto *onnx_graph, ge::Graph &graph);

Status ParseInput(ge::onnx::GraphProto &onnx_graph,
Status ParseInput(bool isSubgraph, ge::onnx::GraphProto *onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
Status ParseOutput(ge::onnx::GraphProto *onnx_graph);

Status ParseInitializer(ge::onnx::GraphProto *onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph);
Status UpdateAllNodeName(ge::onnx::GraphProto *onnx_graph);

Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type);

@@ -92,7 +95,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status GetGraphInputs(std::vector<ge::Operator> &input_ops);

Status Prechecker(ge::onnx::GraphProto &onnx_graph);
Status GetGraphOutputs(std::vector<ge::Operator> &output_ops);

Status Prechecker(ge::onnx::GraphProto *onnx_graph);
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model);

@@ -100,8 +105,15 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);

Status ModelParseToGraphImpl(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, ge::Graph &graph);

void UpdateDataFormat(ge::Graph &graph);

void ClearMembers();

Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);

std::map<std::string, std::string> ori_to_om_type_;

std::map<std::string, int64_t> domain_verseion_;
@@ -110,6 +122,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

std::vector<std::string> input_node_names_;

std::vector<std::string> output_node_names_;

std::map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_;

std::map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_;


+ 48
- 0
parser/onnx/onnx_retval_parser.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "onnx_retval_parser.h"
#include "parser/common/op_parser_factory.h"
#include "parser/onnx/onnx_util.h"

using domi::ONNX;
using namespace ge::parser;

namespace ge {
Status OnnxRetvalParser::ParseParams(const Message *op_src, ge::Operator &op_def) {
GE_CHECK_NOTNULL(op_src);
const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src);
GE_CHECK_NOTNULL(node_src);
GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str());

int64_t index = 0;
for (auto it : node_src->attribute()) {
if (it.name() == ge::kAttrNameIndex) {
index = it.i();
GELOGI("The node has attribute with index: %ld", index);
}
}
op_def.SetAttr(ge::RETVAL_ATTR_NAME_INDEX, index);
ge::GeTensorDesc tensor_desc;
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def);
op_desc->AddInputDesc(tensor_desc);
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(ONNX, _RETVAL, OnnxRetvalParser);
} // namespace ge

+ 29
- 0
parser/onnx/onnx_retval_parser.h View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_
#define GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_

#include "parser/onnx/onnx_op_parser.h"

namespace ge {
class PARSER_FUNC_VISIBILITY OnnxRetvalParser : public OnnxOpParser {
public:
Status ParseParams(const Message *op_src, ge::Operator &op_def) override;
};
} // namespace ge

#endif // GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_

+ 56
- 0
parser/onnx/onnx_subgraph_adapter.h View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_
#define GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY _declspec(dllexport)
#else
#define PARSER_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define PARSER_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <deque>
#include "proto/onnx/ge_onnx.pb.h"
#include "external/register/register_error_codes.h"

using Status = domi::Status;

namespace ge {
class PARSER_FUNC_VISIBILITY OnnxSubgraphAdapter {
public:
/// @brief parse params
/// @param [in] parent_op parent op
/// @return SUCCESS parse success
/// @return FAILED Parse failed
virtual Status AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_op,
std::deque<ge::onnx::GraphProto *> &onnx_graph_tasks,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
return domi::SUCCESS;
}
};
} // namespace ge

#endif // GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_

+ 3
- 0
parser/onnx/onnx_util.h View File

@@ -44,9 +44,12 @@ enum OnnxDataType {
namespace ge {
const char *const kAttrNameValue = "value";
const char *const kAttrNameInput = "input_tensor";
const char *const kAttrNameOutput = "output_tensor";
const char *const kAttrNameIndex = "index";
const char *const kAttrNameIsSubgraphOp = "isSubgraphOp";
const char *const kOpTypeConstant = "Constant";
const char *const kOpTypeInput = "Input";
const char *const kOpTypeOutput = "Output";

class OnnxUtil {
public:


Loading…
Cancel
Save