|
|
|
@@ -65,7 +65,7 @@ int ReduceFp16CPUKernel::CallReduceUnit(int task_id) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
static int ReduceImpl(void *cdata, int task_id) { |
|
|
|
static int ReduceFp16Impl(void *cdata, int task_id) { |
|
|
|
auto reduce = reinterpret_cast<ReduceFp16CPUKernel *>(cdata); |
|
|
|
auto error_code = reduce->CallReduceUnit(task_id); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
@@ -102,7 +102,7 @@ int ReduceFp16CPUKernel::Run() { |
|
|
|
outer_size_ = outer_sizes_[i]; |
|
|
|
inner_size_ = inner_sizes_[i]; |
|
|
|
axis_size_ = axis_sizes_[i]; |
|
|
|
auto error_code = ParallelLaunch(this->context_->thread_pool_, ReduceImpl, this, context_->thread_num_); |
|
|
|
auto error_code = ParallelLaunch(this->context_->thread_pool_, ReduceFp16Impl, this, context_->thread_num_); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
FreeTmpBuffer(); |
|
|
|
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; |
|
|
|
|