Browse Source

!9282 [lite]fix strided slice multi inputs bug

From: @xu_anyue
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
44ea3902b8
3 changed files with 0 additions and 49 deletions
  1. +0
    -45
      mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc
  2. +0
    -3
      mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h
  3. +0
    -1
      mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc

+ 0
- 45
mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc View File

@@ -29,12 +29,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_StridedSlice; using mindspore::schema::PrimitiveType_StridedSlice;


namespace mindspore::kernel { namespace mindspore::kernel {
namespace {
constexpr size_t kMultiInputsSize = 4;
constexpr size_t kBeginsIndex = 1;
constexpr size_t kEndsIndex = 2;
constexpr size_t kStridesInex = 3;
} // namespace
int StridedSliceCPUKernel::Init() { int StridedSliceCPUKernel::Init() {
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;
@@ -57,38 +51,6 @@ int StridedSliceCPUKernel::ReSize() {
return RET_OK; return RET_OK;
} }


int StridedSliceCPUKernel::HandleMultiInputs() {
if (in_tensors_.size() != kMultiInputsSize) {
MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size();
return RET_ERROR;
}
if (param_ == nullptr) {
MS_LOG(ERROR) << "StridedSliceParamater cast nullptr";
return RET_ERROR;
}
auto begins = in_tensors_.at(kBeginsIndex);
MS_ASSERT(begins != nullptr);
int axis_num = begins->ElementsNum();
if (axis_num > DIMENSION_6D) {
MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num;
return RET_ERROR;
}
memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int));

auto ends = in_tensors_.at(kEndsIndex);
MS_ASSERT(ends != nullptr);
MS_ASSERT(axis_num == ends->ElementsNum());
memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int));

auto strides = in_tensors_.at(kStridesInex);
MS_ASSERT(strides != nullptr);
MS_ASSERT(axis_num == strides->ElementsNum());
memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int));

param_->num_axes_ = axis_num;
return RET_OK;
}

int StridedSliceCPUKernel::Run() { int StridedSliceCPUKernel::Run() {
auto input = in_tensors_.at(0); auto input = in_tensors_.at(0);
MS_ASSERT(input); MS_ASSERT(input);
@@ -108,13 +70,6 @@ int StridedSliceCPUKernel::Run() {
} }
auto output = out_tensors_.at(0); auto output = out_tensors_.at(0);
MS_ASSERT(output); MS_ASSERT(output);
// inputs order: input, begin, end, stride
if (in_tensors().size() == kMultiInputsSize) {
auto ret = HandleMultiInputs();
if (ret != RET_OK) {
return ret;
}
}
auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_); auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";


+ 0
- 3
mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h View File

@@ -36,9 +36,6 @@ class StridedSliceCPUKernel : public LiteKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;


private:
int HandleMultiInputs();

private: private:
StridedSliceParameter *param_; StridedSliceParameter *param_;
}; };


+ 0
- 1
mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc View File

@@ -60,7 +60,6 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) {
} }


if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce || if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce ||
opt::GetCNodeType(node) == schema::PrimitiveType_StridedSlice ||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin || opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin ||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax || opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax ||
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch || opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch ||


Loading…
Cancel
Save