|
|
|
@@ -226,6 +226,17 @@ void StridedSlice::ApplyEndMask() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void StridedSlice::TransIndexToPositive() { |
|
|
|
for (int i = 0; i < static_cast<int>(begins_.size()); ++i) { |
|
|
|
if (begins_.at(i) < 0) { |
|
|
|
begins_.at(i) += in_shape_.at(i); |
|
|
|
} |
|
|
|
if (ends_.at(i) < 0) { |
|
|
|
ends_.at(i) += in_shape_.at(i); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { |
|
|
|
MS_ASSERT(this->primitive_ != nullptr); |
|
|
|
if (outputs.size() != kStridedSliceOutputNum) { |
|
|
|
@@ -266,7 +277,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
return RET_INFER_ERR; |
|
|
|
} |
|
|
|
ndim_ = begin_tensor->ElementsNum(); |
|
|
|
for (int i=0; i< ndim_; ++i) { |
|
|
|
for (int i = 0; i < ndim_; ++i) { |
|
|
|
in_shape_.emplace_back(input_shape.at(i)); |
|
|
|
begins_.emplace_back(begin_data[i]); |
|
|
|
ends_.emplace_back(end_data[i]); |
|
|
|
@@ -297,13 +308,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit |
|
|
|
|
|
|
|
output_shape.clear(); |
|
|
|
output_shape.resize(in_shape_.size()); |
|
|
|
|
|
|
|
TransIndexToPositive(); |
|
|
|
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) { |
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) { |
|
|
|
output_shape.at(i) = 1; |
|
|
|
} else if (ends_.at(i) > 0) { |
|
|
|
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); |
|
|
|
} else { |
|
|
|
output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i); |
|
|
|
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|