From 10553f3ab50a4d93307439ac5a199176357e6fcb Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Tue, 15 Dec 2020 16:42:43 +0800 Subject: [PATCH] fix begin length less than input rank --- mindspore/lite/src/ops/strided_slice.cc | 30 +++++++++---------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 8df66abc1e..577229fea0 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -340,6 +340,7 @@ int StridedSlice::HandleAxesInputExist(const std::vector &inputs return RET_OK; } +// note: begin, end, stride length are equal, but may less than rank of input int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kStridedSliceOutputNum) { @@ -359,6 +360,9 @@ int StridedSlice::InferShape(std::vector inputs, std::vector inputs, std::vector(GetBegin().size()); for (int i = 0; i < ndim_; i++) { - if (inferflag) { - in_shape_.emplace_back(input_shape.at(i)); - } begins_.emplace_back((GetBegin()).at(i)); ends_.emplace_back((GetEnd()).at(i)); strides_.emplace_back((GetStride()).at(i)); @@ -391,9 +392,6 @@ int StridedSlice::InferShape(std::vector inputs, std::vectorElementsNum(); for (int i = 0; i < ndim_; ++i) { - if (inferflag) { - in_shape_.emplace_back(input_shape.at(i)); - } begins_.emplace_back(begin_data[i]); ends_.emplace_back(end_data[i]); strides_.emplace_back(stride_data[i]); @@ -431,22 +429,16 @@ int StridedSlice::InferShape(std::vector inputs, std::vector output_shape; - output_shape.clear(); - output_shape.resize(in_shape_.size()); + std::vector output_shape(in_shape_); TransIndexToPositive(); - for (int i = 0; i < static_cast(in_shape_.size()); i++) { - if (i < ndim_ && new_axis_mask_.at(i)) { - output_shape.at(i) = 1; - } else { - if (strides_.at(i) == 0) { - MS_LOG(ERROR) << "strides should not be 0."; - return RET_INFER_ERR; - } - output_shape.at(i) = - (ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i); + for (int i = 0; i < ndim_; i++) { + if (strides_.at(i) == 0) { + MS_LOG(ERROR) << "strides should not be 0."; + return RET_INFER_ERR; } + output_shape.at(i) = + (ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i); } output_shape = ApplyShrinkMask(output_shape);