From 1cd3b2fe03846e546bc76df46e8ae8646f8b95a1 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Mon, 25 Jan 2021 15:24:26 +0800 Subject: [PATCH] resolve issue --- .../parser/caffe/caffe_model_parser.cc | 41 ++++++++++++++--- .../parser/caffe/caffe_model_parser.h | 3 +- .../parser/caffe/caffe_reduce_parser.cc | 6 +-- .../converter/parser/onnx/onnx_relu_parser.cc | 15 ++++--- .../parser/tflite/tflite_fill_parser.cc | 5 ++- .../parser/tflite/tflite_model_parser.cc | 44 ++++++++----------- 6 files changed, 70 insertions(+), 44 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 6860ac7bdb..6e59dbdbb0 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -25,6 +25,13 @@ #include "src/param_value_lite.h" namespace mindspore::lite { +bool IsSkipedLayer(const caffe::LayerParameter &layer) { + if (layer.type() == "Input" || layer.type() == "Dropout" || layer.type() == "Split") { + return true; + } + return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN; +} + CaffeModelParser::CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default; @@ -68,6 +75,11 @@ STATUS CaffeModelParser::ConvertLayers() { } for (int i = 0; i < caffe_model_.layer_size(); i++) { auto layer = caffe_model_.layer(i); + + // save caffe layers + for (int top_idx = 0; top_idx < layer.top_size(); top_idx++) { + caffe_layers_[layer.top(top_idx)] = layer; + } caffe::LayerParameter weight; if (weight_layers.find(layer.name()) != weight_layers.end()) { weight = weight_layers.find(layer.name())->second; @@ -385,11 +397,17 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std:: return RET_NULL_PTR; } for (int i = 0; i < layer.bottom_size(); i++) { - if (nodes_.find(layer.bottom(i)) == nodes_.end()) { + string origin_layer = GetOriginLayerName(layer.bottom(i)); + if (origin_layer.empty()) { + MS_LOG(ERROR) << "layer not found"; + return RET_ERROR; + } + + if (nodes_.find(origin_layer) == nodes_.end()) { MS_LOG(ERROR) << "layer bottom " << layer.bottom(i) << " is not found"; return RET_NOT_FIND_OP; } - input_nodes->emplace_back(nodes_.find(layer.bottom(i))->second); + input_nodes->emplace_back(nodes_.find(origin_layer)->second); } return RET_OK; } @@ -422,11 +440,22 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN return RET_OK; } -bool CaffeModelParser::IsSkipedLayer(const caffe::LayerParameter &layer) { - if (layer.type() == "Input" || layer.type() == "Dropout") { - return true; +std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) { + if (caffe_layers_.find(layer_name) == caffe_layers_.end()) { + return layer_name; } - return layer.include_size() == 1 && layer.include(0).phase() == caffe::TRAIN; + auto layer = caffe_layers_.at(layer_name); + if (layer.type() != "Split") { + return layer_name; + } + while (layer.type() == "Split") { + string input_name = layer.bottom(0); + if (caffe_layers_.find(input_name) == caffe_layers_.end()) { + return input_name; + } + layer = caffe_layers_.at(input_name); + } + return layer.name(); } MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index ca61a934a0..e6fb5221a4 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -55,10 +55,11 @@ class CaffeModelParser : public ModelParser { STATUS ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode); - bool IsSkipedLayer(const caffe::LayerParameter &layer); + std::string GetOriginLayerName(const std::string &layer_name); caffe::NetParameter caffe_model_; caffe::NetParameter caffe_weight_; + std::unordered_map caffe_layers_; std::unordered_map nodes_; FuncGraphPtr func_graph_ptr_; }; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc index 5e890fc357..0961e6fb40 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc @@ -50,11 +50,9 @@ PrimitiveC *CaffeReduceParser::ParseLitePrimitive(const caffe::LayerParameter &p std::vector axes; if (reduce_param.has_axis()) { - axes.push_back(1); - axes.push_back(reduce_param.axis()); + axes = std::vector(1, reduce_param.axis()); } else { - axes.push_back(1); - axes.push_back(0); + axes = std::vector(1, 0); } attr->axes = axes; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc index cb80242c35..14aba17789 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -78,18 +78,21 @@ lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &on MS_LOG(ERROR) << "input error: params[0] is null"; return nullptr; } - const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); - const int64_t slope_size = slope->raw_data().size() / sizeof(float); - if (slope_size == 1) { - attr->slope.push_back(*slope_raw_data); - attr->channelShared = true; + if (slope->float_data_size() > 0) { + const int64_t slope_size = slope->float_data_size(); + for (int64_t i = 0; i < slope_size; i++) { + attr->slope.emplace_back(slope->float_data(i)); + } + attr->channelShared = slope_size == 1; } else { + const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); + const int64_t slope_size = slope->raw_data().size() / sizeof(float); attr->slope.resize(slope_size); - attr->channelShared = false; if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { MS_LOG(ERROR) << "memcpy_s failed"; return nullptr; } + attr->channelShared = slope_size == 1; } } else { MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index ee1195a6c7..6c366ef787 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -36,7 +36,10 @@ PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptrinputs.size() > 1) { - if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { + const auto &tflite_model_buffers = tflite_model->buffers; + const auto &data = tflite_model_buffers.at(tflite_op->inputs[1])->data; + if (!data.empty() && + GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { MS_LOG(ERROR) << "get fill -> dims failed"; return nullptr; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index f19cf8e0d6..08995152e9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -81,6 +81,20 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: return func_graph_; } +std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) { + std::string tensor_name = op_name + "/input-" + std::to_string(index); + if (op_type == tflite::BuiltinOperator_CONV_2D || op_type == tflite::BuiltinOperator_TRANSPOSE_CONV || + op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || op_type == tflite::BuiltinOperator_FULLY_CONNECTED) { + if (index == 1) { + tensor_name = op_name + "/weight"; + } + if (index == 2) { + tensor_name = op_name + "/bias"; + } + } + return tensor_name; +} + STATUS TfliteModelParser::ConvertOps() { const auto &tflite_subgraph = tflite_model_->subgraphs.front(); NoSupportOp::GetInstance()->SetFmkType("TFLITE"); @@ -136,18 +150,7 @@ STATUS TfliteModelParser::ConvertOps() { if (!input_tensor->name.empty()) { tensor_name = input_tensor->name; } else { - tensor_name = op_name + "/input-" + std::to_string(op_inputs.size()); - if (tflite_op_type == tflite::BuiltinOperator_CONV_2D || - tflite_op_type == tflite::BuiltinOperator_TRANSPOSE_CONV || - tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || - tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) { - if (i == 1) { - tensor_name = op_name + "/weight"; - } - if (i == 2) { - tensor_name = op_name + "/bias"; - } - } + tensor_name = GetTensorName(i, tflite_op_type, op_name); } auto parameter = func_graph_->add_parameter(); status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name); @@ -155,18 +158,7 @@ STATUS TfliteModelParser::ConvertOps() { MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; continue; } - - if (tflite_op_type == tflite::BuiltinOperator_CONV_2D || - tflite_op_type == tflite::BuiltinOperator_DEPTHWISE_CONV_2D || - tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED) { - if (op_inputs.size() == 2) { - parameter->set_name(op_name + "/weight"); - } else if (op_inputs.size() == 3) { - parameter->set_name(op_name + "/bias"); - } - } else { - parameter->set_name(op_name + "/input-" + std::to_string(op_inputs.size() - 1)); - } + parameter->set_name(tensor_name); op_inputs.emplace_back(parameter); nodes_.insert(std::pair(input_idx, parameter)); } @@ -364,7 +356,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para MS_LOG(ERROR) << "parameter is null, get const tensor failed."; return RET_NULL_PTR; } - const auto &tfliteModelBuffers = tflite_model_->buffers; + const auto &tflite_model_buffers = tflite_model_->buffers; auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); std::vector shape_vector; (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), @@ -378,7 +370,7 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para param_value->set_tensor_shape(tensor->shape); param_value->set_tensor_type(GetTfliteDataType(tensor->type)); param_value->set_format(schema::Format::Format_NHWC); - const auto &data = tfliteModelBuffers.at(tensor->buffer)->data; + const auto &data = tflite_model_buffers.at(tensor->buffer)->data; if (!data.empty()) { auto size = data.size(); char *tensor_data = new (std::nothrow) char[size];