Browse Source

fix onnx pool parser, Fixed striped_sclice memory out of bounds

tags/v1.1.0
gongdaguo 5 years ago
parent
commit
cd858eae61
3 changed files with 17 additions and 7 deletions
  1. +11
    -6
      mindspore/lite/src/ops/strided_slice.cc
  2. +1
    -0
      mindspore/lite/test/models_onnx.cfg
  3. +5
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc

+ 11
- 6
mindspore/lite/src/ops/strided_slice.cc View File

@@ -254,18 +254,17 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
auto input = inputs.at(0); auto input = inputs.at(0);
outputs.front()->set_data_type(input->data_type()); outputs.front()->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat()); outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
MS_ASSERT(input != nullptr); MS_ASSERT(input != nullptr);
auto input_shape = input->shape(); auto input_shape = input->shape();
std::vector<int> output_shape;
auto inferflag = GetInferFlag();


if (inputs.size() == kStridedSliceInputNum) { if (inputs.size() == kStridedSliceInputNum) {
ndim_ = static_cast<int>(GetBegin().size()); ndim_ = static_cast<int>(GetBegin().size());


for (int i = 0; i < ndim_; i++) { for (int i = 0; i < ndim_; i++) {
in_shape_.emplace_back(input_shape.at(i));
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back((GetBegin())[i]); begins_.emplace_back((GetBegin())[i]);
ends_.emplace_back((GetEnd())[i]); ends_.emplace_back((GetEnd())[i]);
strides_.emplace_back((GetStride())[i]); strides_.emplace_back((GetStride())[i]);
@@ -282,7 +281,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
} }
ndim_ = begin_tensor->ElementsNum(); ndim_ = begin_tensor->ElementsNum();
for (int i = 0; i < ndim_; ++i) { for (int i = 0; i < ndim_; ++i) {
in_shape_.emplace_back(input_shape.at(i));
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back(begin_data[i]); begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]); ends_.emplace_back(end_data[i]);
strides_.emplace_back(stride_data[i]); strides_.emplace_back(stride_data[i]);
@@ -310,6 +311,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
ApplyEndMask(); ApplyEndMask();
ApplyEllipsisMask(); ApplyEllipsisMask();


if (!inferflag) {
return RET_OK;
}
std::vector<int> output_shape;
output_shape.clear(); output_shape.clear();
output_shape.resize(in_shape_.size()); output_shape.resize(in_shape_.size());




+ 1
- 0
mindspore/lite/test/models_onnx.cfg View File

@@ -1,6 +1,7 @@
mtk_detect-mbv2-shortcut-400-400-simplified.onnx mtk_detect-mbv2-shortcut-400-400-simplified.onnx
mtk_emotions-d2012-75.8%.onnx mtk_emotions-d2012-75.8%.onnx
mtk_face_features_v3.onnx mtk_face_features_v3.onnx
emotion-ferplus-8.onnx
ml_face_3d.onnx ml_face_3d.onnx
gts_version-RFB-320_simplified.onnx gts_version-RFB-320_simplified.onnx
mnist-8.onnx mnist-8.onnx


+ 5
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc View File

@@ -89,7 +89,11 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} }
} }
if (attribute_name == "ceil_mode") { if (attribute_name == "ceil_mode") {
attr->roundMode = schema::RoundMode_CEIL;
if (onnx_node_attr.f() == 0) {
attr->roundMode = schema::RoundMode_FLOOR;
} else {
attr->roundMode = schema::RoundMode_CEIL;
}
} }
if (attribute_name == "dilations") { if (attribute_name == "dilations") {
MS_LOG(ERROR) << "pooling op not support dilations now"; MS_LOG(ERROR) << "pooling op not support dilations now";


Loading…
Cancel
Save