|
|
|
@@ -19,9 +19,12 @@ |
|
|
|
#include <iostream> |
|
|
|
#include "common/convert/pb2json.h" |
|
|
|
#include "common/util.h" |
|
|
|
#include "common/ge_types.h" |
|
|
|
#include "common/util/error_manager/error_manager.h" |
|
|
|
#include "external/graph/operator_factory.h" |
|
|
|
#include "external/register/register_error_codes.h" |
|
|
|
#include "external/parser/onnx_parser.h" |
|
|
|
#include "external/ge/ge_api_types.h" |
|
|
|
#include "framework/omg/parser/parser_inner_ctx.h" |
|
|
|
#include "framework/omg/parser/parser_types.h" |
|
|
|
#include "omg/parser/parser_factory.h" |
|
|
|
@@ -35,6 +38,113 @@ |
|
|
|
#include "parser/onnx/onnx_util.h" |
|
|
|
#include "register/op_registry.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, |
|
|
|
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { |
|
|
|
GetParserContext().type = domi::ONNX; |
|
|
|
std::map<string, string> options; |
|
|
|
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); |
|
|
|
|
|
|
|
if (acl_graph_parse_util.AclParserInitialize(options) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Acl parser initialize failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
string output_name; |
|
|
|
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params before graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
// Create an empty computegraph |
|
|
|
string graph_name = output_name.empty() ? "tmpGraph" : output_name; |
|
|
|
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); |
|
|
|
model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); |
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, |
|
|
|
ge::Graph &graph) { |
|
|
|
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params after graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus aclgrphParseONNX(const char *model_file, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
GE_CHECK_NOTNULL(model_file); |
|
|
|
// load custom plugin so and proto |
|
|
|
AclGrphParseUtil acl_graph_parse_util; |
|
|
|
std::shared_ptr<domi::ModelParser> model_parser; |
|
|
|
|
|
|
|
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Prepare before parse failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
|
// parse caffe model_file to GE graph |
|
|
|
ge::graphStatus ret = model_parser->Parse(model_file, graph); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("Parser graph %s success.", graph.GetName().c_str()); |
|
|
|
|
|
|
|
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Handle after parse failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); |
|
|
|
#endif |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
GE_CHECK_NOTNULL(buffer); |
|
|
|
// load custom plugin so and proto |
|
|
|
AclGrphParseUtil acl_graph_parse_util; |
|
|
|
std::shared_ptr<domi::ModelParser> model_parser; |
|
|
|
|
|
|
|
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Prepare before parse failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
// parse caffe model_file to GE graph |
|
|
|
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("Parser graph %s success.", graph.GetName().c_str()); |
|
|
|
|
|
|
|
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Handle after parse failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); |
|
|
|
#endif |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
} // namespace ge |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
std::map<std::string, std::string> kOnnxOpMap = { |
|
|
|
@@ -107,27 +217,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseOutput(const ge::onnx::GraphProto &onnx_graph) { |
|
|
|
if (onnx_graph.output_size() == 0) { |
|
|
|
GELOGE(FAILED, "Onnx graph has zero output"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < onnx_graph.output_size(); i++) { |
|
|
|
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); |
|
|
|
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); |
|
|
|
|
|
|
|
auto it = outputs_map_.find(value_info.name()); |
|
|
|
if (it != outputs_map_.end()) { |
|
|
|
std::string node_name = it->second[0].first; |
|
|
|
output_node_names_.emplace_back(node_name); |
|
|
|
GELOGI("Output node name: %s", node_name.c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, |
|
|
|
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) { |
|
|
|
// Construct const node for weight |
|
|
|
@@ -411,8 +500,7 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_ops, |
|
|
|
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs) { |
|
|
|
Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { |
|
|
|
for (auto in_name : input_node_names_) { |
|
|
|
auto in_op = name_operator_.find(in_name); |
|
|
|
if (in_op == name_operator_.end()) { |
|
|
|
@@ -424,31 +512,39 @@ Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_o |
|
|
|
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
for (auto it : output_node_names_) { |
|
|
|
auto out_op = name_operator_.find(it); |
|
|
|
if (out_op == name_operator_.end()) { |
|
|
|
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", |
|
|
|
it.c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
output_indexs.emplace_back(out_op->second, std::vector<size_t>{}); |
|
|
|
GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str()); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { |
|
|
|
GE_CHECK_NOTNULL(file); |
|
|
|
GELOGI("File path is %s.", file); |
|
|
|
|
|
|
|
// 1. Get graph from onnx model file. |
|
|
|
ge::onnx::ModelProto onnx_model; |
|
|
|
if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
"E19021", {"reason"}, {"Read onnx model file failed."}); |
|
|
|
GELOGE(PARAM_INVALID, "Read onnx model file failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { |
|
|
|
GE_CHECK_NOTNULL(data); |
|
|
|
|
|
|
|
// 1. Get graph from onnx model file. |
|
|
|
if (!ge::parser::ReadProtoFromArray(data, size, &onnx_model)) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
"E19021", {"reason"}, {"Read onnx model from memory failed."}); |
|
|
|
GELOGE(PARAM_INVALID, "Read onnx model from memory failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { |
|
|
|
if (!onnx_model.has_graph()) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E16004"); |
|
|
|
GELOGE(PARAM_INVALID, "Onnx model do not has graph."); |
|
|
|
@@ -515,14 +611,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
// 8. Parse output from graph. |
|
|
|
ret = ParseOutput(onnx_graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Parse output failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
// 9. Set all operator input. |
|
|
|
// 8. Set all operator input. |
|
|
|
ret = SetOperatorInputs(); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Set operator input failed."); |
|
|
|
@@ -533,15 +622,15 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
graph.GetAllOpName(op_names); |
|
|
|
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); |
|
|
|
|
|
|
|
// 10. Construct graph. |
|
|
|
// 9. Construct graph. |
|
|
|
std::vector<ge::Operator> input_ops; |
|
|
|
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs; |
|
|
|
ret = GetGraphInputsOutputs(input_ops, output_indexs); |
|
|
|
|
|
|
|
ret = GetGraphInputs(input_ops); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Get graph inputs and outputs failed."); |
|
|
|
GELOGE(ret, "Get graph inputs failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
graph.SetInputs(input_ops).SetOutputs(output_indexs); |
|
|
|
graph.SetInputs(input_ops); |
|
|
|
|
|
|
|
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); |
|
|
|
|
|
|
|
@@ -551,6 +640,38 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
ge::onnx::ModelProto onnx_model; |
|
|
|
Status ret = GetModelFromFile(file, onnx_model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "get model from file failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
ret = ModelParseToGraph(onnx_model, graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "parse model failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { |
|
|
|
ge::onnx::ModelProto onnx_model; |
|
|
|
Status ret = GetModelFromMemory(data, size, onnx_model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "get model from file failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
ret = ModelParseToGraph(onnx_model, graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "parse model failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { |
|
|
|
if (model_file == nullptr) { |
|
|
|
GELOGE(FAILED, "Model file is nullptr."); |
|
|
|
|