|
|
|
@@ -19,8 +19,8 @@ |
|
|
|
#include <iostream> |
|
|
|
#include "common/convert/pb2json.h" |
|
|
|
#include "common/util.h" |
|
|
|
#include "common/ge_types.h" |
|
|
|
#include "common/util/error_manager/error_manager.h" |
|
|
|
#include "common/ge_types.h" |
|
|
|
#include "external/graph/operator_factory.h" |
|
|
|
#include "external/register/register_error_codes.h" |
|
|
|
#include "external/parser/onnx_parser.h" |
|
|
|
@@ -39,17 +39,16 @@ |
|
|
|
#include "register/op_registry.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, |
|
|
|
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { |
|
|
|
graphStatus aclgrphParseONNX(const char *model_file, |
|
|
|
std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
GE_CHECK_NOTNULL(model_file); |
|
|
|
GetParserContext().type = domi::ONNX; |
|
|
|
std::map<string, string> options; |
|
|
|
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); |
|
|
|
|
|
|
|
if (acl_graph_parse_util.AclParserInitialize(options) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Acl parser initialize failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
// load custom plugin so and proto |
|
|
|
AclGrphParseUtil acl_graph_parse_util; |
|
|
|
(void)acl_graph_parse_util.AclParserInitialize(options); |
|
|
|
|
|
|
|
string output_name; |
|
|
|
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { |
|
|
|
@@ -62,40 +61,9 @@ graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); |
|
|
|
model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); |
|
|
|
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); |
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, |
|
|
|
ge::Graph &graph) { |
|
|
|
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params after graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus aclgrphParseONNX(const char *model_file, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
GE_CHECK_NOTNULL(model_file); |
|
|
|
// load custom plugin so and proto |
|
|
|
AclGrphParseUtil acl_graph_parse_util; |
|
|
|
std::shared_ptr<domi::ModelParser> model_parser; |
|
|
|
|
|
|
|
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Prepare before parse failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
|
// parse caffe model_file to GE graph |
|
|
|
ge::graphStatus ret = model_parser->Parse(model_file, graph); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
@@ -104,44 +72,63 @@ graphStatus aclgrphParseONNX(const char *model_file, |
|
|
|
} |
|
|
|
GELOGI("Parser graph %s success.", graph.GetName().c_str()); |
|
|
|
|
|
|
|
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Handle after parse failed."); |
|
|
|
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params after graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); |
|
|
|
#endif |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, |
|
|
|
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t buffer_size, |
|
|
|
std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { |
|
|
|
GE_CHECK_NOTNULL(buffer); |
|
|
|
GetParserContext().type = domi::ONNX; |
|
|
|
std::map<string, string> options; |
|
|
|
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); |
|
|
|
|
|
|
|
// load custom plugin so and proto |
|
|
|
AclGrphParseUtil acl_graph_parse_util; |
|
|
|
std::shared_ptr<domi::ModelParser> model_parser; |
|
|
|
(void)acl_graph_parse_util.AclParserInitialize(options); |
|
|
|
|
|
|
|
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Prepare before parse failed."); |
|
|
|
string output_name; |
|
|
|
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params before graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
// Create an empty computegraph |
|
|
|
string graph_name = output_name.empty() ? "tmpGraph" : output_name; |
|
|
|
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
// parse caffe model_file to GE graph |
|
|
|
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); |
|
|
|
graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); |
|
|
|
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); |
|
|
|
GE_CHECK_NOTNULL(model_parser); |
|
|
|
|
|
|
|
// parse caffe model_file and weights_file to GE graph |
|
|
|
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)buffer_size, graph); |
|
|
|
if (ret != ge::SUCCESS) { |
|
|
|
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("Parser graph %s success.", graph.GetName().c_str()); |
|
|
|
|
|
|
|
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Handle after parse failed."); |
|
|
|
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Parser params after graph failed."); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { |
|
|
|
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); |
|
|
|
return ge::FAILED; |
|
|
|
} |
|
|
|
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); |
|
|
|
#endif |
|
|
|
return ge::SUCCESS; |
|
|
|
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); |
|
|
|
return ge::SUCCESS; |
|
|
|
} |
|
|
|
} // namespace ge |
|
|
|
|
|
|
|
@@ -159,7 +146,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
|
GELOGE(FAILED, "Onnx graph has zero input"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
// get input value info map |
|
|
|
std::map<std::string, ge::onnx::TensorProto> input_name_tensor; |
|
|
|
for (int i = 0; i < onnx_graph.input_size(); i++) { |
|
|
|
@@ -173,7 +159,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
|
initializer_name_tensor.erase(initializer_iter); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
ge::onnx::TensorProto tensor_tmp; |
|
|
|
if (value_info.has_type()) { |
|
|
|
const ge::onnx::TypeProto type = value_info.type(); |
|
|
|
@@ -194,7 +179,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
|
} |
|
|
|
input_name_tensor[value_info.name()] = tensor_tmp; |
|
|
|
} |
|
|
|
|
|
|
|
// Construct node for input |
|
|
|
int64_t index = 0; |
|
|
|
for (auto it : input_name_tensor) { |
|
|
|
@@ -511,11 +495,10 @@ Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { |
|
|
|
input_ops.emplace_back(in_op->second); |
|
|
|
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { |
|
|
|
Status OnnxModelParser::GetModelFromfile(const char *file, ge::onnx::ModelProto &onnx_model) { |
|
|
|
GE_CHECK_NOTNULL(file); |
|
|
|
GELOGI("File path is %s.", file); |
|
|
|
|
|
|
|
@@ -529,20 +512,18 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { |
|
|
|
GE_CHECK_NOTNULL(data); |
|
|
|
|
|
|
|
// 1. Get graph from onnx model file. |
|
|
|
if (!ge::parser::ReadProtoFromArray(data, size, &onnx_model)) { |
|
|
|
// 1. Get graph from memory. |
|
|
|
if (!ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &onnx_model)) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
|
"E19021", {"reason"}, {"Read onnx model from memory failed."}); |
|
|
|
GELOGE(PARAM_INVALID, "Read onnx model from memory failed."); |
|
|
|
"E19021", {"reason"}, {"Read onnx model file failed."}); |
|
|
|
GELOGE(PARAM_INVALID, "Read onnx model file failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { |
|
|
|
if (!onnx_model.has_graph()) { |
|
|
|
@@ -551,13 +532,11 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
ge::onnx::GraphProto onnx_graph = onnx_model.graph(); |
|
|
|
|
|
|
|
auto opset_import = onnx_model.opset_import(); |
|
|
|
for (auto it : opset_import) { |
|
|
|
domain_verseion_[it.domain()] = it.version(); |
|
|
|
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); |
|
|
|
} |
|
|
|
|
|
|
|
// 2. Get all inializer. |
|
|
|
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; |
|
|
|
for (int i = 0; i < onnx_graph.initializer_size(); i++) { |
|
|
|
@@ -567,7 +546,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// 3. Parse Input from graph. |
|
|
|
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); |
|
|
|
Status ret = ParseInput(onnx_graph, initializer_name_tensor); |
|
|
|
@@ -576,21 +554,18 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
return ret; |
|
|
|
} |
|
|
|
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); |
|
|
|
|
|
|
|
// 4. Parse Constant from graph. |
|
|
|
ret = ParseInitializer(onnx_graph, initializer_name_tensor); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Parse initializer for onnx failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
// 5. Update node name for node do not has name. |
|
|
|
ret = UpdateAllNodeName(onnx_graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Update all node name for onnx failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
// 6 Precheck. |
|
|
|
ret = Prechecker(onnx_graph); |
|
|
|
bool is_precheck_failed = (ret != SUCCESS) || (ge::PreChecker::Instance().HasError()); |
|
|
|
@@ -624,7 +599,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
|
|
|
|
// 9. Construct graph. |
|
|
|
std::vector<ge::Operator> input_ops; |
|
|
|
|
|
|
|
ret = GetGraphInputs(input_ops); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Get graph inputs failed."); |
|
|
|
@@ -642,35 +616,33 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model |
|
|
|
|
|
|
|
Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { |
|
|
|
ge::onnx::ModelProto onnx_model; |
|
|
|
Status ret = GetModelFromFile(file, onnx_model); |
|
|
|
Status ret = GetModelFromfile(file, onnx_model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "get model from file failed."); |
|
|
|
return FAILED; |
|
|
|
GELOGE(ret, "Get model from file failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
ret = ModelParseToGraph(onnx_model, graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "parse model failed."); |
|
|
|
return FAILED; |
|
|
|
GELOGE(ret, "Parse model failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ONLY_COMPILE_OPEN_SRC |
|
|
|
Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { |
|
|
|
ge::onnx::ModelProto onnx_model; |
|
|
|
Status ret = GetModelFromMemory(data, size, onnx_model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "get model from file failed."); |
|
|
|
return FAILED; |
|
|
|
GELOGE(ret, "Get model from memory failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
ret = ModelParseToGraph(onnx_model, graph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "parse model failed."); |
|
|
|
return FAILED; |
|
|
|
GELOGE(ret, "Parse model failed."); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { |
|
|
|
if (model_file == nullptr) { |
|
|
|
|