|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h" |
|
|
|
#include <string.h> |
|
|
|
#include <limits> |
|
|
|
#include <vector> |
|
|
|
#include "schema/model_generated.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
@@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() { |
|
|
|
in_offset_ = nullptr; |
|
|
|
} |
|
|
|
auto in_quant_args = in_tensors_.at(0)->quant_params(); |
|
|
|
auto ind_quant_args = in_tensors_.at(1)->quant_params(); |
|
|
|
auto out_quant_args = out_tensors_.at(0)->quant_params(); |
|
|
|
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale; |
|
|
|
param_.zp_in_ = in_quant_args.front().zeroPoint; |
|
|
|
@@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
(void)memset(in_offset_, 0, count_ * sizeof(int)); |
|
|
|
|
|
|
|
thread_sz_count_ = MSMIN(thread_count_, count_); |
|
|
|
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void GatherNdInt8CPUKernel::InitOffset() { |
|
|
|
auto ind_quant_args = in_tensors_.at(1)->quant_params(); |
|
|
|
auto indices_tensor = in_tensors_.at(1); |
|
|
|
auto indices_shape = indices_tensor->shape(); |
|
|
|
int indices_rank = indices_shape.size(); |
|
|
|
auto in_shape = in_tensors_.front()->shape(); |
|
|
|
int in_rank = in_shape.size(); |
|
|
|
int idx_lastshape = indices_shape.at(indices_rank - 1); |
|
|
|
@@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() { |
|
|
|
in_offset_[j] += tmp * in_stride[k]; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int GatherNdInt8CPUKernel::DoGatherNd(int task_id) { |
|
|
|
@@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) { |
|
|
|
int GatherNdInt8CPUKernel::Run() { |
|
|
|
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData()); |
|
|
|
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData()); |
|
|
|
InitOffset(); |
|
|
|
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; |
|
|
|
|