|
|
|
@@ -340,6 +340,7 @@ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// note: begin, end, stride length are equal, but may less than rank of input |
|
|
|
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { |
|
|
|
MS_ASSERT(this->primitive_ != nullptr); |
|
|
|
if (outputs.size() != kStridedSliceOutputNum) { |
|
|
|
@@ -359,6 +360,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
auto inferflag = infer_flag(); |
|
|
|
|
|
|
|
in_shape_.clear(); |
|
|
|
if (inferflag) { |
|
|
|
in_shape_.assign(input_shape.begin(), input_shape.end()); |
|
|
|
} |
|
|
|
begins_.clear(); |
|
|
|
ends_.clear(); |
|
|
|
strides_.clear(); |
|
|
|
@@ -366,9 +370,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
ndim_ = static_cast<int>(GetBegin().size()); |
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
if (inferflag) { |
|
|
|
in_shape_.emplace_back(input_shape.at(i)); |
|
|
|
} |
|
|
|
begins_.emplace_back((GetBegin()).at(i)); |
|
|
|
ends_.emplace_back((GetEnd()).at(i)); |
|
|
|
strides_.emplace_back((GetStride()).at(i)); |
|
|
|
@@ -391,9 +392,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
} |
|
|
|
ndim_ = begin_tensor->ElementsNum(); |
|
|
|
for (int i = 0; i < ndim_; ++i) { |
|
|
|
if (inferflag) { |
|
|
|
in_shape_.emplace_back(input_shape.at(i)); |
|
|
|
} |
|
|
|
begins_.emplace_back(begin_data[i]); |
|
|
|
ends_.emplace_back(end_data[i]); |
|
|
|
strides_.emplace_back(stride_data[i]); |
|
|
|
@@ -431,22 +429,16 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
if (!inferflag) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
std::vector<int> output_shape; |
|
|
|
output_shape.clear(); |
|
|
|
output_shape.resize(in_shape_.size()); |
|
|
|
std::vector<int> output_shape(in_shape_); |
|
|
|
|
|
|
|
TransIndexToPositive(); |
|
|
|
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) { |
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) { |
|
|
|
output_shape.at(i) = 1; |
|
|
|
} else { |
|
|
|
if (strides_.at(i) == 0) { |
|
|
|
MS_LOG(ERROR) << "strides should not be 0."; |
|
|
|
return RET_INFER_ERR; |
|
|
|
} |
|
|
|
output_shape.at(i) = |
|
|
|
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i); |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
if (strides_.at(i) == 0) { |
|
|
|
MS_LOG(ERROR) << "strides should not be 0."; |
|
|
|
return RET_INFER_ERR; |
|
|
|
} |
|
|
|
output_shape.at(i) = |
|
|
|
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i); |
|
|
|
} |
|
|
|
|
|
|
|
output_shape = ApplyShrinkMask(output_shape); |
|
|
|
|