|
|
|
@@ -91,9 +91,6 @@ int GatherFp16CPUKernel::DoGather(int task_id) { |
|
|
|
auto thread_stride = stride * task_id; |
|
|
|
int8_t *int8_in = nullptr; |
|
|
|
if (input_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
input_data_ = |
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); |
|
|
|
int8_in = reinterpret_cast<int8_t *>(input_data_); |
|
|
|
} else if (input_tensor->data_type() == kNumberTypeFloat16) { |
|
|
|
int8_in = reinterpret_cast<int8_t *>(input_tensor->data_c()); |
|
|
|
@@ -127,7 +124,12 @@ int GatherFp16CPUKernel::Run() { |
|
|
|
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]"; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
auto input_tensor = in_tensors_.at(0); |
|
|
|
if (input_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
input_data_ = |
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); |
|
|
|
} |
|
|
|
ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]"; |
|
|
|
@@ -140,7 +142,6 @@ int GatherFp16CPUKernel::Run() { |
|
|
|
context_->allocator->Free(input_data_); |
|
|
|
input_data_ = nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
|