Browse Source

fix fp16 multi-thread bug

tags/v1.2.0-rc1
wangzhe 5 years ago
parent
commit
f6a7c01d73
13 changed files with 34 additions and 4 deletions
  1. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/base/merge.cc
  2. +4
    -3
      mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc
  3. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc
  4. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/base/switch.cc
  5. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc
  6. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc
  7. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
  8. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
  9. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc
  10. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
  11. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  12. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc
  13. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc

+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/base/merge.cc View File

@@ -139,6 +139,7 @@ int MergeCPUKernel::Run() {
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
} // namespace mindspore::kernel

+ 4
- 3
mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc View File

@@ -41,11 +41,12 @@ int ReshapeBaseCPUKernel::ReSize() {

int ReshapeBaseCPUKernel::RunImpl(int task_id) {
size_t start_index = task_id * cal_max_num_per_thread_;
auto cur_in_ptr = input_ptr_ + start_index;
auto cur_out_ptr = output_ptr_ + start_index;
if (start_index > in_tensors_.front()->Size()) {
if (start_index >= in_tensors_.front()->Size()) {
return RET_OK;
}
auto cur_in_ptr = input_ptr_ + start_index;
auto cur_out_ptr = output_ptr_ + start_index;

size_t data_size = in_tensors_.front()->Size() - start_index;
data_size = data_size > cal_max_num_per_thread_ ? cal_max_num_per_thread_ : data_size;
memcpy(cur_out_ptr, cur_in_ptr, data_size);


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/stack_base.cc View File

@@ -55,7 +55,7 @@ int StackBaseCPUKernel::ReSize() {
axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() + 1 : param->axis_;
auto input_nums = in_tensors_.size();
if (input_nums == 1) {
copy_size_ = in_tensors_.front()->Size();
copy_size_ = in_tensors_.front()->ElementsNum() * data_type_size_;
} else {
MS_ASSERT(input_nums > 1);
copy_size_ = GetCopyNum(input0_shape, axis_, input0_shape.size()) * data_type_size_;


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/base/switch.cc View File

@@ -93,6 +93,7 @@ int SwitchCPUKernel::Run() {
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
} // namespace mindspore::kernel

+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc View File

@@ -51,6 +51,9 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {

int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}

int error_code;
if (type_ == schema::ActivationType_RELU) {


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.cc View File

@@ -126,6 +126,9 @@ int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) {
int cur_offset = stride_per_thread * task_id;
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
if (cur_count <= 0) {
return RET_OK;
}

int ret = RET_OK;
if (param_->broadcasting_) {


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc View File

@@ -169,6 +169,9 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
int cur_offset = stride_per_thread * task_id;
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
if (cur_count <= 0) {
return RET_OK;
}

int ret = RET_OK;
if (param_->broadcasting_) {


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc View File

@@ -88,6 +88,9 @@ int GatherFp16CPUKernel::DoGather(int task_id) {
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat32) {


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc View File

@@ -54,6 +54,9 @@ int ActivationCPUKernel::DoActivation(int task_id) {

int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}

auto ret = RET_OK;



+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc View File

@@ -68,6 +68,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}

if (func_fp32_ == nullptr) {
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -288,6 +288,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}

if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc View File

@@ -57,6 +57,9 @@ int GatherCPUKernel::DoGather(int task_id) {
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;

int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data_c());


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/power_fp32.cc View File

@@ -57,6 +57,9 @@ int PowerCPUKernel::RunImpl(int task_id) {
auto size = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(size, thread_count_);
int len = MSMIN(stride, size - stride * task_id);
if (len <= 0) {
return RET_OK;
}
float *exp_addr = nullptr;
bool broadcast = true;
if (in_tensors_.size() == 2) {


Loading…
Cancel
Save