diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc index 588bae5cbe..e5fc4d2dca 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc @@ -17,6 +17,7 @@ #include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h" #include #include +#include #include "schema/model_generated.h" #include "include/errorcode.h" #include "src/kernel_registry.h" @@ -65,11 +66,17 @@ int GatherNdCPUKernel::ReSize() { return RET_ERROR; } (void)memset(in_offset_, 0, count_ * sizeof(int)); - thread_sz_count_ = MSMIN(thread_count_, count_); thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + return RET_OK; +} +void GatherNdCPUKernel::InitOffset() { + MS_ASSERT(in_offset_ != nullptr); + auto indices_tensor = in_tensors_.at(1); + auto indices_shape = indices_tensor->shape(); auto in_shape = in_tensors_.front()->shape(); + int indices_rank = indices_shape.size(); int in_rank = in_shape.size(); int idx_lastshape = indices_shape[indices_rank - 1]; auto indices_ptr = reinterpret_cast(indices_tensor->MutableData()); @@ -89,8 +96,6 @@ int GatherNdCPUKernel::ReSize() { in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k); } } - - return RET_OK; } int GatherNdCPUKernel::DoGatherNd(int task_id) { @@ -120,6 +125,7 @@ int GatherNdRun(void *cdata, int task_id) { int GatherNdCPUKernel::Run() { in_ptr_ = reinterpret_cast(in_tensors_.front()->MutableData()); out_ptr_ = reinterpret_cast(out_tensors_.front()->MutableData()); + InitOffset(); auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h index 7f719ae29d..b24978abcf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h @@ -41,6 +41,7 @@ class GatherNdCPUKernel : public LiteKernel { int DoGatherNd(int task_id); private: + void InitOffset(); int thread_sz_count_; int thread_sz_stride_; int count_; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc index 4f15af5593..968daa223e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.cc @@ -17,6 +17,7 @@ #include "src/runtime/kernel/arm/int8/gatherNd_int8.h" #include #include +#include #include "schema/model_generated.h" #include "include/errorcode.h" #include "src/kernel_registry.h" @@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() { in_offset_ = nullptr; } auto in_quant_args = in_tensors_.at(0)->quant_params(); - auto ind_quant_args = in_tensors_.at(1)->quant_params(); auto out_quant_args = out_tensors_.at(0)->quant_params(); param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale; param_.zp_in_ = in_quant_args.front().zeroPoint; @@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() { return RET_ERROR; } (void)memset(in_offset_, 0, count_ * sizeof(int)); - thread_sz_count_ = MSMIN(thread_count_, count_); thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + return RET_OK; +} +void GatherNdInt8CPUKernel::InitOffset() { + auto ind_quant_args = in_tensors_.at(1)->quant_params(); + auto indices_tensor = in_tensors_.at(1); + auto indices_shape = indices_tensor->shape(); + int indices_rank = indices_shape.size(); auto in_shape = in_tensors_.front()->shape(); int in_rank = in_shape.size(); int idx_lastshape = indices_shape.at(indices_rank - 1); @@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() { in_offset_[j] += tmp * in_stride[k]; } } - return RET_OK; } int GatherNdInt8CPUKernel::DoGatherNd(int task_id) { @@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) { int GatherNdInt8CPUKernel::Run() { in_ptr_ = reinterpret_cast(in_tensors_.front()->MutableData()); out_ptr_ = reinterpret_cast(out_tensors_.front()->MutableData()); + InitOffset(); auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h index b6529f00e4..92294824b3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h @@ -36,6 +36,7 @@ class GatherNdInt8CPUKernel : public LiteKernel { int DoGatherNd(int task_id); private: + void InitOffset(); int thread_count_; int thread_sz_count_; int thread_sz_stride_;