Browse Source

!251 add stage set

Merge pull request !251 from 王笑天/master
pull/251/MERGE
i-robot Gitee 5 years ago
parent
commit
c5ff9e6922
4 changed files with 17 additions and 2 deletions
  1. +1
    -1
      metadef
  2. +6
    -0
      parser/caffe/caffe_parser.cc
  3. +3
    -1
      parser/onnx/onnx_parser.cc
  4. +7
    -0
      parser/tensorflow/tensorflow_parser.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 3df578f6d08e51e2c4ed8a023fb8482109941665
Subproject commit 4f983b12aa972e9e89e1c212f4d4443aea00ef31

+ 6
- 0
parser/caffe/caffe_parser.cc View File

@@ -74,6 +74,7 @@ using std::ifstream;


namespace ge { namespace ge {
graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) { graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_file); GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::CAFFE; GetParserContext().type = domi::CAFFE;
std::map<string, string> options; std::map<string, string> options;
@@ -121,6 +122,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,


graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_file); GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::CAFFE; GetParserContext().type = domi::CAFFE;
std::map<string, string> options; std::map<string, string> options;
@@ -1464,6 +1466,7 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) {
} }


Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
bool has_error = false; bool has_error = false;


GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "model data is nullptr."); GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "model data is nullptr.");
@@ -1586,6 +1589,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
} }


Status CaffeModelParser::Parse(const char *model_path, ge::Graph &graph) { Status CaffeModelParser::Parse(const char *model_path, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_path); GE_CHECK_NOTNULL(model_path);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);
@@ -1865,6 +1869,7 @@ Status CaffeModelParser::ReorderInput(domi::caffe::NetParameter &net) {
} }


Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
if (data == nullptr) { if (data == nullptr) {
GELOGE(PARAM_INVALID, "Caffe weights data is nullptr"); GELOGE(PARAM_INVALID, "Caffe weights data is nullptr");
return PARAM_INVALID; return PARAM_INVALID;
@@ -1892,6 +1897,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::
} }


Status CaffeWeightsParser::Parse(const char *file, ge::Graph &graph) { Status CaffeWeightsParser::Parse(const char *file, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(file); GE_CHECK_NOTNULL(file);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);


+ 3
- 1
parser/onnx/onnx_parser.cc View File

@@ -83,7 +83,7 @@ graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util,
} }


graphStatus aclgrphParseONNX(const char *model_file, graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file); GE_CHECK_NOTNULL(model_file);
// load custom plugin so and proto // load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util; AclGrphParseUtil acl_graph_parse_util;
@@ -641,6 +641,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
} }


Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
ge::onnx::ModelProto onnx_model; ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromFile(file, onnx_model); Status ret = GetModelFromFile(file, onnx_model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
@@ -656,6 +657,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
} }


Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
ge::onnx::ModelProto onnx_model; ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromMemory(data, size, onnx_model); Status ret = GetModelFromMemory(data, size, onnx_model);
if (ret != SUCCESS) { if (ret != SUCCESS) {


+ 7
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -88,6 +88,7 @@ using ge::parser::ModelSaver;


namespace ge { namespace ge {
graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_file); GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::TENSORFLOW; GetParserContext().type = domi::TENSORFLOW;
std::map<string, string> options; std::map<string, string> options;
@@ -127,6 +128,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {


graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params, graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) { ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_file); GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::TENSORFLOW; GetParserContext().type = domi::TENSORFLOW;
std::map<string, string> options; std::map<string, string> options;
@@ -1074,6 +1076,7 @@ Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef
} }


Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(data); GE_CHECK_NOTNULL(data);
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);


@@ -1216,6 +1219,7 @@ Status TensorFlowModelParser::GetFunctionProto(const string &file,
} }


Status TensorFlowModelParser::Parse(const char *model_path, ge::Graph &graph) { Status TensorFlowModelParser::Parse(const char *model_path, ge::Graph &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(model_path); GE_CHECK_NOTNULL(model_path);
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(root_graph); GE_CHECK_NOTNULL(root_graph);
@@ -1309,6 +1313,7 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr
} }


Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
GE_CHECK_NOTNULL(proto); GE_CHECK_NOTNULL(proto);
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);


@@ -2141,6 +2146,7 @@ Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size,
Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; } Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; }


Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
PARSER_TIMESTAMP_START(ParseProto); PARSER_TIMESTAMP_START(ParseProto);
GE_CHECK_NOTNULL(proto); GE_CHECK_NOTNULL(proto);
@@ -2268,6 +2274,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,


Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
domi::GetGraphCallback callback, ge::ComputeGraphPtr &root_graph) { domi::GetGraphCallback callback, ge::ComputeGraphPtr &root_graph) {
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kParser);
ErrorManager::GetInstance().GenWorkStreamIdDefault(); ErrorManager::GetInstance().GenWorkStreamIdDefault();
GE_CHECK_NOTNULL(root_proto); GE_CHECK_NOTNULL(root_proto);
GE_CHECK_NOTNULL(callback); GE_CHECK_NOTNULL(callback);


Loading…
Cancel
Save