|
|
|
@@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1; |
|
|
|
void StridedSlice::ApplyNewAxisMask() { |
|
|
|
for (int i = 0; i < new_axis_mask_.size(); i++) { |
|
|
|
if (new_axis_mask_.at(i)) { |
|
|
|
updated_ndim_ += 1; |
|
|
|
updated_in_shape_.insert(updated_in_shape_.begin() + i, 1); |
|
|
|
updated_begins_.at(i) = 0; |
|
|
|
updated_ends_.at(i) = 1; |
|
|
|
updated_strides_.at(i) = 1; |
|
|
|
updated_begins_.emplace_back(0); |
|
|
|
updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1)); |
|
|
|
updated_strides_.emplace_back(1); |
|
|
|
ndim_ += 1; |
|
|
|
in_shape_.insert(in_shape_.begin() + i, 1); |
|
|
|
begins_.at(i) = 0; |
|
|
|
ends_.at(i) = 1; |
|
|
|
strides_.at(i) = 1; |
|
|
|
|
|
|
|
begins_.emplace_back(0); |
|
|
|
ends_.emplace_back(in_shape_.at(ndim_ - 1)); |
|
|
|
strides_.emplace_back(1); |
|
|
|
|
|
|
|
begins_mask_.at(i) = false; |
|
|
|
ends_mask_.at(i) = false; |
|
|
|
ellipsis_mask_.at(i) = false; |
|
|
|
shrink_axis_mask_.at(i) = false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -47,8 +53,8 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) { |
|
|
|
out_shape.clear(); |
|
|
|
for (int i = 0; i < shrink_axis_mask_.size(); i++) { |
|
|
|
if (shrink_axis_mask_.at(i)) { |
|
|
|
updated_ends_.at(i) = updated_begins_.at(i) + 1; |
|
|
|
updated_strides_.at(i) = 1; |
|
|
|
ends_.at(i) = begins_.at(i) + 1; |
|
|
|
strides_.at(i) = 1; |
|
|
|
} else { |
|
|
|
out_shape.emplace_back(old_out_shape.at(i)); |
|
|
|
} |
|
|
|
@@ -63,22 +69,26 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) { |
|
|
|
void StridedSlice::ApplyEllipsisMask() { |
|
|
|
for (int i = 0; i < ellipsis_mask_.size(); i++) { |
|
|
|
if (ellipsis_mask_.at(i)) { |
|
|
|
updated_begins_.at(i) = 0; |
|
|
|
updated_ends_.at(i) = updated_in_shape_.at(i); |
|
|
|
begins_.at(i) = 0; |
|
|
|
ends_.at(i) = in_shape_.at(i); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void StridedSlice::ApplyBeginMask() { |
|
|
|
for (int i = 0; i < ori_ndim_; i++) { |
|
|
|
updated_begins_.at(i) = 0; |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
if (begins_mask_.at(i)) { |
|
|
|
begins_.at(i) = 0; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void StridedSlice::ApplyEndMask() { |
|
|
|
for (int i = 0; i < ori_ndim_; i++) { |
|
|
|
updated_ends_.at(i) = 0; |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
if (ends_.at(i)) { |
|
|
|
ends_.at(i) = in_shape_.at(i); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t |
|
|
|
MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
if (inputs.size() < kStridedSliceInputNum) { |
|
|
|
if (inputs.size() != kStridedSliceInputNum) { |
|
|
|
MS_LOG(ERROR) << "Invalid input size " << inputs.size(); |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
@@ -97,28 +107,28 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t |
|
|
|
auto input_shape = input->shape(); |
|
|
|
std::vector<int> output_shape; |
|
|
|
auto strided_slice_prim = this->primitive->value_as_StridedSlice(); |
|
|
|
updated_ndim_ = static_cast<int>(strided_slice_prim->begin()->size()); |
|
|
|
ori_ndim_ = updated_ndim_; |
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->end()->size())); |
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->stride()->size())); |
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(input_shape.size())); |
|
|
|
|
|
|
|
for (int i = 0; i < updated_ndim_; i++) { |
|
|
|
updated_in_shape_.emplace_back(input_shape.at(i)); |
|
|
|
updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]); |
|
|
|
updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]); |
|
|
|
updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]); |
|
|
|
ndim_ = static_cast<int>(strided_slice_prim->begin()->size()); |
|
|
|
|
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size())); |
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size())); |
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size())); |
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
in_shape_.emplace_back(input_shape.at(i)); |
|
|
|
begins_.emplace_back((*(strided_slice_prim->begin()))[i]); |
|
|
|
ends_.emplace_back((*(strided_slice_prim->end()))[i]); |
|
|
|
strides_.emplace_back((*(strided_slice_prim->stride()))[i]); |
|
|
|
} |
|
|
|
|
|
|
|
// set all mask to original input shape |
|
|
|
begins_mask_.resize(updated_ndim_); |
|
|
|
ends_mask_.resize(updated_ndim_); |
|
|
|
ellipsis_mask_.resize(updated_ndim_); |
|
|
|
new_axis_mask_.resize(updated_ndim_); |
|
|
|
shrink_axis_mask_.resize(updated_ndim_); |
|
|
|
begins_mask_.resize(ndim_); |
|
|
|
ends_mask_.resize(ndim_); |
|
|
|
ellipsis_mask_.resize(ndim_); |
|
|
|
new_axis_mask_.resize(ndim_); |
|
|
|
shrink_axis_mask_.resize(ndim_); |
|
|
|
|
|
|
|
// convert bit to vector |
|
|
|
for (int i = 0; i < updated_ndim_; i++) { |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
begins_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->beginMask()) & (1 << i); |
|
|
|
ends_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->endMask()) & (1 << i); |
|
|
|
ellipsis_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->ellipsisMask()) & (1 << i); |
|
|
|
@@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t |
|
|
|
} |
|
|
|
|
|
|
|
ApplyNewAxisMask(); |
|
|
|
ApplyNewAxisMask(); |
|
|
|
ApplyBeginMask(); |
|
|
|
ApplyEndMask(); |
|
|
|
ApplyEllipsisMask(); |
|
|
|
|
|
|
|
output_shape.resize(updated_in_shape_.size()); |
|
|
|
for (int i = 0; i < updated_in_shape_.size(); i++) { |
|
|
|
if (i < ori_ndim_ && new_axis_mask_.at(i)) { |
|
|
|
output_shape.clear(); |
|
|
|
output_shape.resize(in_shape_.size()); |
|
|
|
for (int i = 0; i < in_shape_.size(); i++) { |
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) { |
|
|
|
output_shape.at(i) = 1; |
|
|
|
} else { |
|
|
|
// begins and ends out of range handling |
|
|
|
if (updated_begins_.at(i) >= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) || |
|
|
|
updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) { |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i); |
|
|
|
updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i); |
|
|
|
|
|
|
|
if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) || |
|
|
|
(updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) { |
|
|
|
output_shape.at(i) = 0; |
|
|
|
} else { |
|
|
|
output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i); |
|
|
|
} |
|
|
|
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|