Browse Source

!12392 [MSLITE] Fix the bug of fp16 gather working in multithreading

From: @zhanyuan1
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
1c8fd7eac0
1 changed files with 6 additions and 5 deletions
  1. +6
    -5
      mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc

+ 6
- 5
mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc View File

@@ -91,9 +91,6 @@ int GatherFp16CPUKernel::DoGather(int task_id) {
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
int8_in = reinterpret_cast<int8_t *>(input_data_);
} else if (input_tensor->data_type() == kNumberTypeFloat16) {
int8_in = reinterpret_cast<int8_t *>(input_tensor->data_c());
@@ -127,7 +124,12 @@ int GatherFp16CPUKernel::Run() {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}

auto input_tensor = in_tensors_.at(0);
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
}
ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
@@ -140,7 +142,6 @@ int GatherFp16CPUKernel::Run() {
context_->allocator->Free(input_data_);
input_data_ = nullptr;
}

return ret;
}



Loading…
Cancel
Save