From f6a7c01d73bee495a62a1bc5130fb539dd8eaccd Mon Sep 17 00:00:00 2001 From: wangzhe Date: Mon, 8 Feb 2021 10:00:52 +0800 Subject: [PATCH] fix fp16 multi threads bug --- mindspore/lite/src/runtime/kernel/arm/base/merge.cc | 1 + mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc | 7 ++++--- mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc | 2 +- mindspore/lite/src/runtime/kernel/arm/base/switch.cc | 1 + .../lite/src/runtime/kernel/arm/fp16/activation_fp16.cc | 3 +++ .../src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc | 3 +++ .../lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc | 3 +++ mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc | 3 +++ .../lite/src/runtime/kernel/arm/fp32/activation_fp32.cc | 3 +++ .../src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc | 3 +++ .../lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc | 3 +++ mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc | 3 +++ mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc | 3 +++ 13 files changed, 34 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc index 2b4737ba47..6034403319 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc @@ -139,6 +139,7 @@ int MergeCPUKernel::Run() { } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Merge, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Merge, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc index 3b2f09683c..a0a79c9769 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -41,11 +41,12 @@ int ReshapeBaseCPUKernel::ReSize() { int ReshapeBaseCPUKernel::RunImpl(int task_id) { size_t start_index = task_id * cal_max_num_per_thread_; - auto cur_in_ptr = input_ptr_ + start_index; - auto cur_out_ptr = output_ptr_ + start_index; - if (start_index > in_tensors_.front()->Size()) { + if (start_index >= in_tensors_.front()->Size()) { return RET_OK; } + auto cur_in_ptr = input_ptr_ + start_index; + auto cur_out_ptr = output_ptr_ + start_index; + size_t data_size = in_tensors_.front()->Size() - start_index; data_size = data_size > cal_max_num_per_thread_ ? cal_max_num_per_thread_ : data_size; memcpy(cur_out_ptr, cur_in_ptr, data_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc index 3b327a0e5a..f3a656b5cf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc @@ -55,7 +55,7 @@ int StackBaseCPUKernel::ReSize() { axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() + 1 : param->axis_; auto input_nums = in_tensors_.size(); if (input_nums == 1) { - copy_size_ = in_tensors_.front()->Size(); + copy_size_ = in_tensors_.front()->ElementsNum() * data_type_size_; } else { MS_ASSERT(input_nums > 1); copy_size_ = GetCopyNum(input0_shape, axis_, input0_shape.size()) * data_type_size_; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc index cbd681eb54..a5c89f26e3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc @@ -93,6 +93,7 @@ int SwitchCPUKernel::Run() { } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Switch, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Switch, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc index e66af3ced7..09c49fc64a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc @@ -51,6 +51,9 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) { int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); + if (count <= 0) { + return RET_OK; + } int error_code; if (type_ == schema::ActivationType_RELU) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc index d40f7caac3..278f1b0fe6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc @@ -126,6 +126,9 @@ int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) { int cur_offset = stride_per_thread * task_id; int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset) : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset); + if (cur_count <= 0) { + return RET_OK; + } int ret = RET_OK; if (param_->broadcasting_) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 888bb8b91d..097941fe83 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -169,6 +169,9 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { int cur_offset = stride_per_thread * task_id; int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset) : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset); + if (cur_count <= 0) { + return RET_OK; + } int ret = RET_OK; if (param_->broadcasting_) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc index 6be60ca84d..c0e379d367 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc @@ -88,6 +88,9 @@ int GatherFp16CPUKernel::DoGather(int task_id) { } int stride = UP_DIV(outer_size, op_parameter_->thread_num_); int count = MSMIN(stride, outer_size - stride * task_id); + if (count <= 0) { + return RET_OK; + } auto thread_stride = stride * task_id; int8_t *int8_in = nullptr; if (input_tensor->data_type() == kNumberTypeFloat32) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc index 1e444061c6..316d7f2e38 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc @@ -54,6 +54,9 @@ int ActivationCPUKernel::DoActivation(int task_id) { int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); + if (count <= 0) { + return RET_OK; + } auto ret = RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index 4386125476..f791787e4b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -68,6 +68,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { MS_ASSERT(thread_count_ != 0); int stride = UP_DIV(element_num, thread_count_); int count = MSMIN(stride, element_num - stride * task_id); + if (count <= 0) { + return RET_OK; + } if (func_fp32_ == nullptr) { MS_LOG(ERROR) << "func_fp32_ function is nullptr!"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 51db2c8152..de92dd7219 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -288,6 +288,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { MS_ASSERT(thread_count_ != 0); int stride = UP_DIV(element_num, thread_count_); int count = MSMIN(stride, element_num - stride * task_id); + if (count <= 0) { + return RET_OK; + } if (arithmetic_run_ == nullptr) { MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc index 7df3f72a83..616644722c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc @@ -57,6 +57,9 @@ int GatherCPUKernel::DoGather(int task_id) { } int stride = UP_DIV(outer_size, op_parameter_->thread_num_); int count = MSMIN(stride, outer_size - stride * task_id); + if (count <= 0) { + return RET_OK; + } auto thread_stride = stride * task_id; int8_t *int8_in = reinterpret_cast(input_tensor->data_c()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc index e8b041e26c..64160b15af 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc @@ -57,6 +57,9 @@ int PowerCPUKernel::RunImpl(int task_id) { auto size = in_tensors_.at(0)->ElementsNum(); int stride = UP_DIV(size, thread_count_); int len = MSMIN(stride, size - stride * task_id); + if (len <= 0) { + return RET_OK; + } float *exp_addr = nullptr; bool broadcast = true; if (in_tensors_.size() == 2) {