Browse Source

!9992 fix strided slice shrink axis mask

From: @zhaozhenlong
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
99fca84f94
1 changed files with 11 additions and 19 deletions
  1. +11
    -19
      mindspore/lite/src/ops/strided_slice.cc

+ 11
- 19
mindspore/lite/src/ops/strided_slice.cc View File

@@ -340,6 +340,7 @@ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs
return RET_OK;
}

// note: begin, end, stride length are equal, but may less than rank of input
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceOutputNum) {
@@ -359,6 +360,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
auto inferflag = infer_flag();

in_shape_.clear();
if (inferflag) {
in_shape_.assign(input_shape.begin(), input_shape.end());
}
begins_.clear();
ends_.clear();
strides_.clear();
@@ -366,9 +370,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
ndim_ = static_cast<int>(GetBegin().size());

for (int i = 0; i < ndim_; i++) {
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back((GetBegin()).at(i));
ends_.emplace_back((GetEnd()).at(i));
strides_.emplace_back((GetStride()).at(i));
@@ -391,9 +392,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}
ndim_ = begin_tensor->ElementsNum();
for (int i = 0; i < ndim_; ++i) {
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]);
strides_.emplace_back(stride_data[i]);
@@ -431,22 +429,16 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (!inferflag) {
return RET_OK;
}
std::vector<int> output_shape;
output_shape.clear();
output_shape.resize(in_shape_.size());
std::vector<int> output_shape(in_shape_);

TransIndexToPositive();
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else {
if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;
}
output_shape.at(i) =
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
for (int i = 0; i < ndim_; i++) {
if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;
}
output_shape.at(i) =
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
}

output_shape = ApplyShrinkMask(output_shape);


Loading…
Cancel
Save