Browse Source

Pre Merge pull request !641 from xujiuxu/ge_dev

pull/641/MERGE
xujiuxu Gitee 3 years ago
parent
commit
3795ec534b
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 80 additions and 78 deletions
  1. +57
    -78
      parser/caffe/caffe_parser.cc
  2. +23
    -0
      parser/caffe/caffe_parser.h

+ 57
- 78
parser/caffe/caffe_parser.cc View File

@@ -1647,22 +1647,14 @@ Status CaffeModelParser::ReorderInput(domi::caffe::NetParameter &net) const {
for (const auto &it : move_input_vec) {
if (it.moveType == domi::OMG_INPUT_REORDER) {
auto inputs = layer->bottom();
if (static_cast<size_t>(inputs.size()) != it.input_order.size()) {
REPORT_INNER_ERROR("E19999", "Size of input is mismatched, check invalid,"
"new order size is %zu, input size is %d.", it.input_order.size(), inputs.size());
GELOGE(INTERNAL_ERROR, "[Check][Size]Size of input is mismatched, new order size is %zu, input size is %d.",
it.input_order.size(), inputs.size());
return INTERNAL_ERROR;
}
GE_CHK_BOOL_RET_STATUS(static_cast<size_t>(inputs.size()) == it.input_order.size(), INTERNAL_ERROR,
"[Check][Size]Size of input is mismatched, new order size is %zu, input size is %d.",
it.input_order.size(), inputs.size());
for (size_t j = 0; j < it.input_order.size(); ++j) {
int new_index = it.input_order[j];
if (new_index < 0 || new_index >= inputs.size()) {
REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, which is out of range, "
"inputs size:%d.", layer->name().c_str(), new_index, inputs.size());
GELOGE(INTERNAL_ERROR, "[Check][Param]New order of %s has invalid index %d, which is out of range, "
"inputs size:%d.", layer->name().c_str(), new_index, inputs.size());
return INTERNAL_ERROR;
}
GE_CHK_BOOL_RET_STATUS((new_index >= 0) && (new_index < inputs.size()), INTERNAL_ERROR,
"[Check][Param]New order of %s has invalid index %d, which is out of range, "
"inputs size:%d.", layer->name().c_str(), new_index, inputs.size());
layer->set_bottom(j, inputs[new_index]);
}
GELOGI("The input sequence of the node has been rearranged, node name:%s.", layer->name().c_str());
@@ -1720,7 +1712,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::Graph &graph) {
}

Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) {
if (file == nullptr) {
if (file == nullptr) {
REPORT_INNER_ERROR("E19999", "param file is nullptr, check invalid.");
GELOGE(FAILED, "[Check][Param]Caffe weights parse fail, Parameter file invalid");
return PARAM_INVALID;
@@ -1745,40 +1737,27 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) {
GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str());
fusion_proto_file = caffe_proto_path;
} else {
if (proto_file_parser.CombineProtoFile(caffe_proto_path.c_str(), custom_proto_path.c_str(),\
fusion_proto_file) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "CombineProtoFile failed, caffe_proto_path:%s, custom_proto_path:%s.",
caffe_proto_path.c_str(), custom_proto_path.c_str());
GELOGE(FAILED, "[Invoke][CombineProtoFile]Create tmp fusion proto file from caffe and custom proto failed.");
return FAILED;
}
GE_CHK_BOOL_RET_STATUS(proto_file_parser.CombineProtoFile(caffe_proto_path.c_str(), custom_proto_path.c_str(),
fusion_proto_file) == SUCCESS, FAILED, "[Invoke][CombineProtoFile]"
"Create tmp fusion proto file from caffe and custom proto failed.");
}

string fusion_proto_path = ge::parser::RealPath(fusion_proto_file.c_str());
GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str());
if (fusion_proto_path.empty()) {
REPORT_INNER_ERROR("E19999", "Fusion proto file path [%s] is not real existed.",
fusion_proto_file.c_str());
GELOGE(FAILED, "[Invoke][RealPath]Fusion proto file path [%s]-[%s] is not real existed.",
fusion_proto_file.c_str(), fusion_proto_path.c_str());
return FAILED;
}

GE_CHK_BOOL_RET_STATUS(!fusion_proto_path.empty(), FAILED,
"[Invoke][RealPath]Fusion proto file path [%s]-[%s] is not real existed.",
fusion_proto_file.c_str(), fusion_proto_path.c_str());
string fusion_proto_name;
if (CheckPathValid(file, fusion_proto_file, fusion_proto_path, fusion_proto_name) != SUCCESS) {
GELOGE(FAILED, "[Check][PathValid] of weight file[%s] and tmp proto[%s] failed.", file,
fusion_proto_file.c_str());
return FAILED;
}
GE_CHK_STATUS_RET(CheckPathValid(file, fusion_proto_file, fusion_proto_path, fusion_proto_name),
"[Check][PathValid] of weight file[%s] and tmp proto[%s] failed.",
file, fusion_proto_file.c_str());

GELOGI("Start to parse weight: %s by fusion proto: %s.", file, fusion_proto_file.c_str());
Status status = ParseWeightByFusionProto(file, fusion_proto_path, fusion_proto_name, graph);
if (status != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseWeightByFusionProto] failed. ret:%u", status);
return status;
}
GE_CHK_STATUS_RET(ParseWeightByFusionProto(file, fusion_proto_path, fusion_proto_name, graph),
"[Invoke][ParseWeightByFusionProto] failed.");

status = CheckNodes(graph);
Status status = CheckNodes(graph);
if (status != SUCCESS) {
GELOGE(ge::GRAPH_FAILED, "[Check][Nodes] failed, status=%u", status);
return domi::PARSE_WEIGHTS_FAILED;
@@ -2182,13 +2161,8 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);

if (op_parser.get() == nullptr) {
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
return FAILED;
}
GE_CHK_BOOL_RET_STATUS(op_parser.get() != nullptr, FAILED,
"[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());

// Parsing weight information through op parser
Status status = op_parser->ParseWeights(layer_message, node);
@@ -2236,14 +2210,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
return SUCCESS;
}

Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
int num_layer = param.layer_size();
int num_layers = param.layers_size();

// Operator name and occurrence map, handle duplicate operators
std::map<std::string, int32_t> layer_name_map;

Status CaffeWeightsParser::CheckLayerNumValid(const int32_t num_layer, const int32_t num_layers) {
if (num_layer == 0 && num_layers > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11023");
GELOGE(FAILED, "[Check][Param] The weight file is consisted of layers-structure "
@@ -2256,7 +2223,40 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
GELOGE(FAILED, "weight layer num is zero, weight file may be invalid.");
return FAILED;
}
return SUCCESS;
}

Status CaffeWeightsParser::ParserWeightByOpParserFactory(const string op_type, const domi::caffe::LayerParameter &layer,
const string &layer_name, ge::NodePtr node) {
// create OpParser
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);

if (op_parser.get() == nullptr) {
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
return FAILED;
}

// Parsing weight information through op parser
Status status = op_parser->ParseWeights(&layer, node);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str());
return status;
}
return SUCCESS;
}

Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
int num_layer = param.layer_size();
int num_layers = param.layers_size();
GE_CHK_STATUS_RET_NOLOG(CheckLayerNumValid(num_layer, num_layers));
// Operator name and occurrence map, handle duplicate operators
std::map<std::string, int32_t> layer_name_map;
for (int i = 0; i < num_layer; ++i) {
const LayerParameter &layer = param.layer(i);
const string &param_layer_name = layer.name();
@@ -2295,12 +2295,10 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
ge::NodePtr node = graph->FindNode(layer_name);
layer_name_map.insert(std::make_pair(layer_name, kNumOne));
if (node == nullptr) {
// If there are redundant layers in the weight file, they should be skipped rather than returned with an error.
GELOGI("Layer %s not found in graph", layer_name.c_str());
continue;
}

// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later.
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer.type());
if (iter == caffe_op_map.end()) {
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str());
@@ -2308,26 +2306,7 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
}
GELOGD("Caffe layer name: %s , layer type: %s.", layer_name.c_str(), layer.type().c_str());
string op_type = iter->second;

// create OpParser
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);

if (op_parser.get() == nullptr) {
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
return FAILED;
}

// Parsing weight information through op parser
Status status = op_parser->ParseWeights(&layer, node);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str());
return status;
}
GE_CHK_STATUS_RET_NOLOG(ParserWeightByOpParserFactory(op_type, layer, layer_name, node));
}
}



+ 23
- 0
parser/caffe/caffe_parser.h View File

@@ -381,6 +381,29 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
*/
static Status ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph);

/**
* @ingroup domi_omg
* @brief Check layer number and layers number invalid or not
* @param [in] layer size
* @param [in] layers size
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
static Status CheckLayerNumValid(const int32_t num_layer, const int32_t num_layers);

/**
* @ingroup domi_omg
* @brief parse weigth by op factory
* @param [in] op_type type of op
* @param [in] layer layer parameters
* @param [in] layer_name name of layer
* @param [in] node ge node ptr
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
static Status ParserWeightByOpParserFactory(const string op_type, const domi::caffe::LayerParameter &layer,
const string &layer_name, ge::NodePtr node);

Status Parse(const char *file, ge::ComputeGraphPtr &graph);

Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path,


Loading…
Cancel
Save