|
|
|
@@ -29,12 +29,6 @@ using mindspore::lite::RET_OK; |
|
|
|
using mindspore::schema::PrimitiveType_StridedSlice; |
|
|
|
|
|
|
|
namespace mindspore::kernel { |
|
|
|
namespace { |
|
|
|
constexpr size_t kMultiInputsSize = 4; |
|
|
|
constexpr size_t kBeginsIndex = 1; |
|
|
|
constexpr size_t kEndsIndex = 2; |
|
|
|
constexpr size_t kStridesInex = 3; |
|
|
|
} // namespace |
|
|
|
int StridedSliceCPUKernel::Init() { |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
@@ -57,38 +51,6 @@ int StridedSliceCPUKernel::ReSize() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int StridedSliceCPUKernel::HandleMultiInputs() { |
|
|
|
if (in_tensors_.size() != kMultiInputsSize) { |
|
|
|
MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (param_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "StridedSliceParamater cast nullptr"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto begins = in_tensors_.at(kBeginsIndex); |
|
|
|
MS_ASSERT(begins != nullptr); |
|
|
|
int axis_num = begins->ElementsNum(); |
|
|
|
if (axis_num > DIMENSION_6D) { |
|
|
|
MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int)); |
|
|
|
|
|
|
|
auto ends = in_tensors_.at(kEndsIndex); |
|
|
|
MS_ASSERT(ends != nullptr); |
|
|
|
MS_ASSERT(axis_num == ends->ElementsNum()); |
|
|
|
memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int)); |
|
|
|
|
|
|
|
auto strides = in_tensors_.at(kStridesInex); |
|
|
|
MS_ASSERT(strides != nullptr); |
|
|
|
MS_ASSERT(axis_num == strides->ElementsNum()); |
|
|
|
memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int)); |
|
|
|
|
|
|
|
param_->num_axes_ = axis_num; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int StridedSliceCPUKernel::Run() { |
|
|
|
auto input = in_tensors_.at(0); |
|
|
|
MS_ASSERT(input); |
|
|
|
@@ -108,13 +70,6 @@ int StridedSliceCPUKernel::Run() { |
|
|
|
} |
|
|
|
auto output = out_tensors_.at(0); |
|
|
|
MS_ASSERT(output); |
|
|
|
// inputs order: input, begin, end, stride |
|
|
|
if (in_tensors().size() == kMultiInputsSize) { |
|
|
|
auto ret = HandleMultiInputs(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
} |
|
|
|
auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; |
|
|
|
|