Browse Source

!13301 [MS][LITE]reverse op bug

From: @fuzhiye
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
edaffb66af
2 changed files with 15 additions and 1 deletions
  1. +13
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc
  2. +2
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h

+ 13
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc View File

@@ -37,6 +37,9 @@ int ReverseCPUKernel::Stride(int index) {
}

int ReverseCPUKernel::ReSize() {
// trans negative to positive axis
UpdateAxisInfo();

data_size_ = in_tensors_.at(0)->ElementsNum();
thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
@@ -134,6 +137,16 @@ int ReverseCPUKernel::Run() {
return RET_OK;
}

void ReverseCPUKernel::UpdateAxisInfo() {
auto reverse_param = reinterpret_cast<ReverseParameter *>(op_parameter_);
int in_shape_len = in_tensors_.front()->shape().size();
for (int i = 0; i < reverse_param->num_axis_; ++i) {
if (reverse_param->axis_[i] < 0) {
reverse_param->axis_[i] += in_shape_len;
}
}
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
} // namespace mindspore::kernel

+ 2
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.h View File

@@ -28,7 +28,7 @@ class ReverseCPUKernel : public LiteKernel {
ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {}
~ReverseCPUKernel() {
~ReverseCPUKernel() override {
if (tmp_ != nullptr) {
free(tmp_);
tmp_ = nullptr;
@@ -42,6 +42,7 @@ class ReverseCPUKernel : public LiteKernel {
int DoReverse(int task_id);

private:
void UpdateAxisInfo();
int thread_sz_count_ = 0;
int thread_sz_stride_ = 0;
int data_size_ = 0;


Loading…
Cancel
Save