|
|
|
@@ -31,6 +31,7 @@ using mindspore::lite::RET_OK; |
|
|
|
using mindspore::schema::PrimitiveType_Mean; |
|
|
|
using mindspore::schema::PrimitiveType_Reduce; |
|
|
|
using mindspore::schema::ReduceMode; |
|
|
|
using mindspore::schema::ReduceMode_ReduceAll; |
|
|
|
using mindspore::schema::ReduceMode_ReduceASum; |
|
|
|
using mindspore::schema::ReduceMode_ReduceMax; |
|
|
|
using mindspore::schema::ReduceMode_ReduceMean; |
|
|
|
@@ -78,6 +79,10 @@ int ReduceCPUKernel::Init() { |
|
|
|
reducer_ = ReduceSum; |
|
|
|
break; |
|
|
|
} |
|
|
|
case static_cast<int>(ReduceMode_ReduceAll): { |
|
|
|
bool_reducer_ = ReduceAll; |
|
|
|
break; |
|
|
|
} |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -96,6 +101,9 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) { |
|
|
|
if (data_type_ == kDataTypeFloat) { |
|
|
|
ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_), |
|
|
|
static_cast<float *>(dst_data_), task_id, context_->thread_num_); |
|
|
|
} else if (data_type_ == KDataTypeBool) { |
|
|
|
ret = bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_), |
|
|
|
static_cast<bool *>(dst_data_), task_id, context_->thread_num_); |
|
|
|
} else { |
|
|
|
ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_), |
|
|
|
static_cast<int *>(dst_data_), task_id, context_->thread_num_); |
|
|
|
@@ -117,6 +125,8 @@ int ReduceImpl(void *cdata, int task_id) { |
|
|
|
int ReduceCPUKernel::Run() { |
|
|
|
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) { |
|
|
|
data_type_ = kDataTypeFloat; |
|
|
|
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) { |
|
|
|
data_type_ = KDataTypeBool; |
|
|
|
} else { |
|
|
|
data_type_ = kDataTypeInt; |
|
|
|
} |
|
|
|
@@ -202,6 +212,8 @@ int ReduceCPUKernel::MallocTmpBuffer() { |
|
|
|
void *buffer = nullptr; |
|
|
|
if (data_type_ == kDataTypeFloat) { |
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(float)); |
|
|
|
} else if (data_type_ == KDataTypeBool) { |
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(bool)); |
|
|
|
} else { |
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(int)); |
|
|
|
} |
|
|
|
|