Browse Source

fix gathernd bug

tags/v1.2.0-rc1
gongdaguo 4 years ago
parent
commit
6a4cd208d0
4 changed files with 20 additions and 6 deletions
  1. +9
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc
  2. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h
  3. +9
    -3
      mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc
  4. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h

+ 9
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc View File

@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h"
#include <string.h>
#include <limits>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
@@ -65,11 +66,17 @@ int GatherNdCPUKernel::ReSize() {
return RET_ERROR;
}
(void)memset(in_offset_, 0, count_ * sizeof(int));

thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
return RET_OK;
}

void GatherNdCPUKernel::InitOffset() {
MS_ASSERT(in_offset_ != nullptr);
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
auto in_shape = in_tensors_.front()->shape();
int indices_rank = indices_shape.size();
int in_rank = in_shape.size();
int idx_lastshape = indices_shape[indices_rank - 1];
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->MutableData());
@@ -89,8 +96,6 @@ int GatherNdCPUKernel::ReSize() {
in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k);
}
}

return RET_OK;
}

int GatherNdCPUKernel::DoGatherNd(int task_id) {
@@ -120,6 +125,7 @@ int GatherNdRun(void *cdata, int task_id) {
int GatherNdCPUKernel::Run() {
in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
InitOffset();
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h View File

@@ -41,6 +41,7 @@ class GatherNdCPUKernel : public LiteKernel {
int DoGatherNd(int task_id);

private:
void InitOffset();
int thread_sz_count_;
int thread_sz_stride_;
int count_;


+ 9
- 3
mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc View File

@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
#include <string.h>
#include <limits>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
@@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() {
in_offset_ = nullptr;
}
auto in_quant_args = in_tensors_.at(0)->quant_params();
auto ind_quant_args = in_tensors_.at(1)->quant_params();
auto out_quant_args = out_tensors_.at(0)->quant_params();
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
param_.zp_in_ = in_quant_args.front().zeroPoint;
@@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() {
return RET_ERROR;
}
(void)memset(in_offset_, 0, count_ * sizeof(int));

thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
return RET_OK;
}

void GatherNdInt8CPUKernel::InitOffset() {
auto ind_quant_args = in_tensors_.at(1)->quant_params();
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
int indices_rank = indices_shape.size();
auto in_shape = in_tensors_.front()->shape();
int in_rank = in_shape.size();
int idx_lastshape = indices_shape.at(indices_rank - 1);
@@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() {
in_offset_[j] += tmp * in_stride[k];
}
}
return RET_OK;
}

int GatherNdInt8CPUKernel::DoGatherNd(int task_id) {
@@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) {
int GatherNdInt8CPUKernel::Run() {
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData());
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData());
InitOffset();
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h View File

@@ -36,6 +36,7 @@ class GatherNdInt8CPUKernel : public LiteKernel {
int DoGatherNd(int task_id);

private:
void InitOffset();
int thread_count_;
int thread_sz_count_;
int thread_sz_stride_;


Loading…
Cancel
Save