|
|
|
@@ -31,8 +31,6 @@ using mindspore::schema::PrimitiveType_Gather; |
|
|
|
namespace mindspore::kernel { |
|
|
|
|
|
|
|
int GatherCPUKernel::Init() { |
|
|
|
axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_; |
|
|
|
batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_; |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -47,7 +45,7 @@ int GatherCPUKernel::DoGather(int task_id) { |
|
|
|
auto out_tensor = out_tensors_.at(0); |
|
|
|
|
|
|
|
auto input_ptr = reinterpret_cast<float *>(input_tensor->Data()); |
|
|
|
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data()); |
|
|
|
auto indices_ptr = reinterpret_cast<float *>(indices_tensor->Data()); |
|
|
|
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data()); |
|
|
|
|
|
|
|
auto input_int32 = reinterpret_cast<int32_t *>(input_tensor->Data()); |
|
|
|
@@ -56,26 +54,25 @@ int GatherCPUKernel::DoGather(int task_id) { |
|
|
|
auto in_shape = input_tensor->shape(); |
|
|
|
int in_rank = in_shape.size(); |
|
|
|
int indices_element_size = indices_tensor->ElementsNum(); |
|
|
|
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_; |
|
|
|
|
|
|
|
const int limit = in_shape[axis_]; |
|
|
|
const int limit = in_shape[axis]; |
|
|
|
for (int i = 0; i < indices_element_size; ++i) { |
|
|
|
if (indices_ptr[i] >= limit) { |
|
|
|
MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]"; |
|
|
|
indices_data_[i] = static_cast<int>(indices_ptr[i]); |
|
|
|
if (indices_data_[i] >= limit) { |
|
|
|
MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int outer_size = 1; |
|
|
|
for (int i = 0; i < axis_; ++i) { |
|
|
|
int outer_size = 1, inner_size = 1; |
|
|
|
for (int i = 0; i < axis; ++i) { |
|
|
|
outer_size *= in_shape[i]; |
|
|
|
} |
|
|
|
|
|
|
|
int inner_size = 1; |
|
|
|
for (int i = axis_ + 1; i < in_rank; ++i) { |
|
|
|
for (int i = axis + 1; i < in_rank; ++i) { |
|
|
|
inner_size *= in_shape[i]; |
|
|
|
} |
|
|
|
|
|
|
|
int stride = UP_DIV(outer_size, thread_count_); |
|
|
|
int stride = UP_DIV(outer_size, op_parameter_->thread_num_); |
|
|
|
int count = MSMIN(stride, outer_size - stride * task_id); |
|
|
|
auto thread_stride = stride * task_id; |
|
|
|
|
|
|
|
@@ -83,17 +80,13 @@ int GatherCPUKernel::DoGather(int task_id) { |
|
|
|
if (input_tensor->data_type() == kNumberTypeInt32) { |
|
|
|
input_int32 += thread_stride * limit; |
|
|
|
output_int32 += thread_stride * indices_element_size; |
|
|
|
error_code = GatherInt32(input_int32, count, inner_size, limit, indices_ptr, indices_element_size, output_int32); |
|
|
|
error_code = GatherInt32(input_int32, count, inner_size, limit, indices_data_, indices_element_size, output_int32); |
|
|
|
} else { |
|
|
|
input_ptr += thread_stride * limit; |
|
|
|
output_ptr += thread_stride * indices_element_size; |
|
|
|
error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
if (error_code != RET_OK) { |
|
|
|
return RET_ERROR; |
|
|
|
error_code = Gather(input_ptr, count, inner_size, limit, indices_data_, indices_element_size, output_ptr); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
return error_code; |
|
|
|
} |
|
|
|
|
|
|
|
int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { |
|
|
|
@@ -101,9 +94,8 @@ int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { |
|
|
|
auto error_code = gather_kernel->DoGather(task_id); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
return error_code; |
|
|
|
} |
|
|
|
|
|
|
|
int GatherCPUKernel::Run() { |
|
|
|
@@ -112,12 +104,19 @@ int GatherCPUKernel::Run() { |
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; |
|
|
|
return prepare_ret; |
|
|
|
} |
|
|
|
int error_code = LiteBackendParallelLaunch(GatherRun, this, thread_count_); |
|
|
|
|
|
|
|
auto indices_tensor = in_tensors_.at(1); |
|
|
|
indices_data_ = reinterpret_cast<int *>(context_->allocator->Malloc(indices_tensor->ElementsNum() * sizeof(int))); |
|
|
|
if (indices_data_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Memory allocation failed"; |
|
|
|
context_->allocator->Free(indices_data_); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
int error_code = LiteBackendParallelLaunch(GatherRun, this, op_parameter_->thread_num_); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
return error_code; |
|
|
|
} |
|
|
|
|
|
|
|
kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, |
|
|
|
|