| @@ -17,6 +17,7 @@ | |||||
| #include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h" | #include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h" | ||||
| #include <string.h> | #include <string.h> | ||||
| #include <limits> | #include <limits> | ||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| @@ -65,11 +66,17 @@ int GatherNdCPUKernel::ReSize() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| (void)memset(in_offset_, 0, count_ * sizeof(int)); | (void)memset(in_offset_, 0, count_ * sizeof(int)); | ||||
| thread_sz_count_ = MSMIN(thread_count_, count_); | thread_sz_count_ = MSMIN(thread_count_, count_); | ||||
| thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); | thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); | ||||
| return RET_OK; | |||||
| } | |||||
| void GatherNdCPUKernel::InitOffset() { | |||||
| MS_ASSERT(in_offset_ != nullptr); | |||||
| auto indices_tensor = in_tensors_.at(1); | |||||
| auto indices_shape = indices_tensor->shape(); | |||||
| auto in_shape = in_tensors_.front()->shape(); | auto in_shape = in_tensors_.front()->shape(); | ||||
| int indices_rank = indices_shape.size(); | |||||
| int in_rank = in_shape.size(); | int in_rank = in_shape.size(); | ||||
| int idx_lastshape = indices_shape[indices_rank - 1]; | int idx_lastshape = indices_shape[indices_rank - 1]; | ||||
| auto indices_ptr = reinterpret_cast<int *>(indices_tensor->MutableData()); | auto indices_ptr = reinterpret_cast<int *>(indices_tensor->MutableData()); | ||||
| @@ -89,8 +96,6 @@ int GatherNdCPUKernel::ReSize() { | |||||
| in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k); | in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k); | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| int GatherNdCPUKernel::DoGatherNd(int task_id) { | int GatherNdCPUKernel::DoGatherNd(int task_id) { | ||||
| @@ -120,6 +125,7 @@ int GatherNdRun(void *cdata, int task_id) { | |||||
| int GatherNdCPUKernel::Run() { | int GatherNdCPUKernel::Run() { | ||||
| in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData()); | in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData()); | ||||
| out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); | out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); | ||||
| InitOffset(); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_); | auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; | MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; | ||||
| @@ -41,6 +41,7 @@ class GatherNdCPUKernel : public LiteKernel { | |||||
| int DoGatherNd(int task_id); | int DoGatherNd(int task_id); | ||||
| private: | private: | ||||
| void InitOffset(); | |||||
| int thread_sz_count_; | int thread_sz_count_; | ||||
| int thread_sz_stride_; | int thread_sz_stride_; | ||||
| int count_; | int count_; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "src/runtime/kernel/arm/int8/gatherNd_int8.h" | #include "src/runtime/kernel/arm/int8/gatherNd_int8.h" | ||||
| #include <string.h> | #include <string.h> | ||||
| #include <limits> | #include <limits> | ||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| @@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() { | |||||
| in_offset_ = nullptr; | in_offset_ = nullptr; | ||||
| } | } | ||||
| auto in_quant_args = in_tensors_.at(0)->quant_params(); | auto in_quant_args = in_tensors_.at(0)->quant_params(); | ||||
| auto ind_quant_args = in_tensors_.at(1)->quant_params(); | |||||
| auto out_quant_args = out_tensors_.at(0)->quant_params(); | auto out_quant_args = out_tensors_.at(0)->quant_params(); | ||||
| param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale; | param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale; | ||||
| param_.zp_in_ = in_quant_args.front().zeroPoint; | param_.zp_in_ = in_quant_args.front().zeroPoint; | ||||
| @@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| (void)memset(in_offset_, 0, count_ * sizeof(int)); | (void)memset(in_offset_, 0, count_ * sizeof(int)); | ||||
| thread_sz_count_ = MSMIN(thread_count_, count_); | thread_sz_count_ = MSMIN(thread_count_, count_); | ||||
| thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); | thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); | ||||
| return RET_OK; | |||||
| } | |||||
| void GatherNdInt8CPUKernel::InitOffset() { | |||||
| auto ind_quant_args = in_tensors_.at(1)->quant_params(); | |||||
| auto indices_tensor = in_tensors_.at(1); | |||||
| auto indices_shape = indices_tensor->shape(); | |||||
| int indices_rank = indices_shape.size(); | |||||
| auto in_shape = in_tensors_.front()->shape(); | auto in_shape = in_tensors_.front()->shape(); | ||||
| int in_rank = in_shape.size(); | int in_rank = in_shape.size(); | ||||
| int idx_lastshape = indices_shape.at(indices_rank - 1); | int idx_lastshape = indices_shape.at(indices_rank - 1); | ||||
| @@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() { | |||||
| in_offset_[j] += tmp * in_stride[k]; | in_offset_[j] += tmp * in_stride[k]; | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| int GatherNdInt8CPUKernel::DoGatherNd(int task_id) { | int GatherNdInt8CPUKernel::DoGatherNd(int task_id) { | ||||
| @@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) { | |||||
| int GatherNdInt8CPUKernel::Run() { | int GatherNdInt8CPUKernel::Run() { | ||||
| in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData()); | in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData()); | ||||
| out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData()); | out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData()); | ||||
| InitOffset(); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_); | auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; | MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; | ||||
| @@ -36,6 +36,7 @@ class GatherNdInt8CPUKernel : public LiteKernel { | |||||
| int DoGatherNd(int task_id); | int DoGatherNd(int task_id); | ||||
| private: | private: | ||||
| void InitOffset(); | |||||
| int thread_count_; | int thread_count_; | ||||
| int thread_sz_count_; | int thread_sz_count_; | ||||
| int thread_sz_stride_; | int thread_sz_stride_; | ||||