Browse Source

Pre Merge pull request !385 from 李正龙/r1.6.0-person

pull/385/MERGE
李正龙 Gitee 4 years ago
parent
commit
ba2efa5312
7 changed files with 46 additions and 57 deletions
  1. +2
    -6
      parser/caffe/caffe_data_parser.cc
  2. +5
    -9
      parser/caffe/caffe_parser.cc
  3. +1
    -0
      parser/common/parser_utils.cc
  4. +29
    -27
      parser/common/pre_checker.cc
  5. +1
    -3
      parser/onnx/onnx_constant_parser.cc
  6. +1
    -3
      parser/onnx/onnx_data_parser.cc
  7. +7
    -9
      parser/tensorflow/tensorflow_parser.cc

+ 2
- 6
parser/caffe/caffe_data_parser.cc View File

@@ -80,9 +80,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l
vector<int64_t> shape;
std::map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
}
std::copy(blob_shape.dim().begin(), blob_shape.dim().end(), model_dims.begin());
string name = layer->name();
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name));
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op),
@@ -124,9 +122,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete
vector<int64_t> shape;
std::map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
}
std::copy(blob_shape.dim().begin(), blob_shape.dim().end(), model_dims.begin());

string name = layer->name();
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name));


+ 5
- 9
parser/caffe/caffe_parser.cc View File

@@ -1159,12 +1159,10 @@ bool CaffeModelParser::CheckValidLayer(const domi::caffe::LayerParameter &layer)
}

bool CaffeModelParser::IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name) {
for (auto &bottom_name : layer.bottom()) {
if (top_name == bottom_name) {
return true;
}
}
return false;
return std::any_of(layer.bottom().begin(), layer.bottom().end(),
[=](const std::string &bottom_name) {
return (top_name == bottom_name);
});
}

std::string CaffeModelParser::RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name,
@@ -1358,9 +1356,7 @@ Status CaffeModelParser::Parse(const char *model_path, ge::Graph &graph) {
void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer) {
string name = layer.name();
vector<string> tops;
for (auto top : layer.top()) {
tops.push_back(top);
}
std::copy(layer.top().begin(), layer.top().end(), tops.begin());
auto it = layer_tops_map_.find(name);
if (it == layer_tops_map_.end()) {
layer_tops_map_[name] = tops;


+ 1
- 0
parser/common/parser_utils.cc View File

@@ -26,6 +26,7 @@
#include "graph/utils/node_adapter.h"
#include "graph/utils/op_desc_utils.h"
#include "register/op_registry.h"
#include "operator.h"

namespace ge {
namespace {


+ 29
- 27
parser/common/pre_checker.cc View File

@@ -93,19 +93,20 @@ Status PreChecker::CheckName(OpId id) {
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "[Check][Param] Id does not exist.");

Info &info = iter->second;
for (auto &v : op_map_) {
auto is_exist = std::find_if(op_map_.begin(), op_map_.end(),
[=](const pair<OpId, Info> &v) {
return (id != v.first && info.name == v.second.name);
});
if (is_exist != op_map_.end()) {
// If the name is duplicate, an error is logged
if (id != v.first && info.name == v.second.name) {
Cause cause;
cause.code = NAME_REPEATED;
cause.message = "The name is repeated.";
Cause cause;
cause.code = NAME_REPEATED;
cause.message = "The name is repeated.";

GELOGI("Name %s repeated.", info.name.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name});
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed.");
break;
}
GELOGI("Name %s repeated.", info.name.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name});
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddCause((*is_exist).first, cause), "[Add][Cause] failed.");
}

return SUCCESS;
@@ -142,11 +143,13 @@ FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddCause(OpId id, ErrorCode code, co

FMK_FUNC_HOST_VISIBILITY void PreChecker::RefreshErrorMessageByName(const string &op_name, ErrorCode code,
const string &msg) {
for (const auto &op : op_map_) {
if (op.second.name == op_name) {
AddCause(op.second.id, code, msg);
return;
}
auto is_exist = std::find_if(op_map_.begin(), op_map_.end(),
[=](const pair<OpId, Info> &op) {
return (op.second.name == op_name);
});
if (is_exist != op_map_.end()) {
AddCause((*is_exist).second.id, code, msg);
return;
}
GELOGW("Node [%s] not founded in prechecking list.", op_name.c_str());
}
@@ -158,10 +161,12 @@ Status PreChecker::AddCause(OpId id, const Cause &cause) {
Info &info = iter->second;

// Avoid adding repeatedly
for (Cause &c : info.causes) {
if (c.code == cause.code && c.message == cause.message) {
return SUCCESS;
}
auto is_exist = std::any_of(info.causes.begin(), info.causes.end(),
[=](Cause &c) {
return (c.code == cause.code && c.message == cause.message);
});
if (is_exist) {
return SUCCESS;
}

info.causes.push_back(cause);
@@ -279,12 +284,9 @@ bool PreChecker::HasError(OpId id) {
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "[Check][Param] Id does not exist.");

Info &info = iter->second;
for (const Cause &cause : info.causes) {
if (cause.code != ErrorCode::OK) {
return true;
}
}

return false;
return std::any_of(info.causes.begin(), info.causes.end(),
[=](Cause &cause) {
return (cause.code != ErrorCode::OK);
});
}
} // namespace ge

+ 1
- 3
parser/onnx/onnx_constant_parser.cc View File

@@ -123,9 +123,7 @@ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &t
// for string values
case OnnxDataType::STRING: {
std::vector<std::string> data;
for (auto str_data : tensor_proto.string_data()) {
data.emplace_back(str_data);
}
std::copy(tensor_proto.string_data().begin(), tensor_proto.string_data().end(), data.begin());
tensor.SetData(data);
break;
}


+ 1
- 3
parser/onnx/onnx_data_parser.cc View File

@@ -64,9 +64,7 @@ int64_t OnnxDataParser::ParseInputTensor(const ge::onnx::AttributeProto &attribu
const ::ge::onnx::TensorProto it_tensor = attribute.t();
int64_t data_type = it_tensor.data_type();
GELOGI("Attr name: %s, data type: %ld ", attribute.name().c_str(), data_type);
for (auto dim : it_tensor.dims()) {
model_input_dims_v_.push_back(dim);
}
std::copy(it_tensor.dims().begin(), it_tensor.dims().end(), model_input_dims_v_.begin());
return data_type;
}



+ 7
- 9
parser/tensorflow/tensorflow_parser.cc View File

@@ -596,11 +596,10 @@ void TensorFlowModelParser::GetInputOutputTensorNum(ge::OpDescPtr &op_desc, size
// input number
input_tensor_num = 0;
for (auto &input_vec : dest_input_map) {
for (auto &input_v : input_vec.second) {
if (input_v.second != kControlSlot) {
input_tensor_num++;
}
}
input_tensor_num = std::count_if(input_vec.second.begin(), input_vec.second.end(),
[=](const std::pair<int32_t, int32_t> &input_v) {
return (input_v.second != kControlSlot);
});
}

// output number
@@ -2219,10 +2218,9 @@ bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const i
// If the node name is included, then confirm whether the index is the same
auto iter = edges_control_map.find(node_name);
if (iter != edges_control_map.end()) {
for (auto &i : iter->second) {
if (i == index) {
return true;
}
auto is_exist = std::any_of(iter->second.begin(), iter->second.end(), [=](int32_t &i) {return (i == index);});
if (is_exist) {
return true;
}
}



Loading…
Cancel
Save