@@ -60,7 +60,11 @@ int ArithmeticCPUKernel::ReSize() {
       outside_ *= param_->out_shape_[i];
     }
   }
-  return ConstTensorBroadCast();
+  int ret = RET_OK;
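+  // only pre-broadcast const inputs when none of the fast paths below can handle the shapes at run time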
+  if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) {
+    ret = ConstTensorBroadCast();
+  }
+  return ret;
 }

 int ArithmeticCPUKernel::CheckDataType() {
@@ -73,6 +77,47 @@ int ArithmeticCPUKernel::CheckDataType() {
   return RET_OK;
 }

+bool ArithmeticCPUKernel::isScalarClac() {  // 2 32 240 240, 1 1 1 1
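+  // one operand is a single element and an optimized element-by-scalar kernel is registered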
+  if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool ArithmeticCPUKernel::isBatchScalarCalc() {  // 2 32 240 240, 2 32 1 1
+  if (arithmetic_opt_run_ == nullptr) {
+    return false;
+  }
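+  // find the first axis where the input shapes diverge; input1 must be 1 on that axis and every axis after it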
+  size_t break_axis = 0;
+  for (size_t i = 0; i < param_->ndim_; i++) {
+    if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
+      break_axis = i;
+      break;
+    }
+  }
+  if (break_axis < param_->ndim_) {
+    for (size_t i = break_axis; i < param_->ndim_; i++) {
+      if (param_->in_shape1_[i] != 1) {
+        return false;
+      }
+    }
+  }
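+  // record where the shapes diverge; BatchScalarCalc() later splits the work around this position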
+  break_pos_ = break_axis;
+  return true;
+}
+
+bool ArithmeticCPUKernel::isBiasCalc() {  // 2 240 240 32, 1 1 1 32
+  int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
+  int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
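+  // the smaller operand must be exactly one last-axis row of the larger one (a per-channel bias)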
+  if (param_->in_elements_num0_ > param_->in_elements_num1_) {
+    return param_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
+  } else if (param_->in_elements_num0_ < param_->in_elements_num1_) {
+    return param_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
+  }
+  return false;
+}
+
 int ArithmeticCPUKernel::ConstTensorBroadCast() {
   /* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
   if (!param_->broadcasting_) {
@@ -86,11 +131,6 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
       param_->in_elements_num1_ != param_->out_elements_num_) {
     return RET_OK;
  }
-  if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
-    /* run opt function
-     * one of input is scalar */
-    return RET_OK;
-  }

   FreeConstTileBuff();
   if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) {
@@ -252,32 +292,6 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
   return RET_OK;
 }

-bool ArithmeticCPUKernel::CanBatchScalar() {  // 2 32 240 240, 2 32 1 1
-  if (input0_broadcast_ || input1_broadcast_) {
-    return false;
-  }
-  if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 ||
-      param_->in_elements_num1_ == 1) {
-    return false;
-  }
-  size_t break_axis = 0;
-  for (size_t i = 0; i < param_->ndim_; i++) {
-    if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
-      break_axis = i;
-      break;
-    }
-  }
-  if (break_axis < param_->ndim_) {
-    for (size_t i = break_axis; i < param_->ndim_; i++) {
-      if (param_->in_shape1_[i] != 1) {
-        return false;
-      }
-    }
-  }
-  break_pos_ = break_axis;
-  return true;
-}
-
 int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
   if (break_pos_ < 1) {
     return RET_ERROR;
@@ -308,6 +322,40 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
   return ret;
 }

+int ArithmeticCPUKernel::BiasCalc(int task_id) {
+  int last_shape = param_->out_shape_[param_->ndim_ - 1];
+  int batch = param_->out_elements_num_ / last_shape;
+  int batch_per_thread = UP_DIV(batch, context_->thread_num_);
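+  // split the batch dimension evenly across threads; each task takes a contiguous range of rows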
+
+  int start_batch = batch_per_thread * task_id;
+  int end_batch = MSMIN(start_batch + batch_per_thread, batch);
+  int batch_size = end_batch - start_batch;
+
+  int stride = last_shape * data_type_len_;
+  int offset = stride * start_batch;
+  int ret = RET_OK;
+  if (param_->in_elements_num0_ > param_->in_elements_num1_) {
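+    // input0 is the full tensor: advance its pointer row by row and reuse the bias row in input1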
+    for (int i = 0; i < batch_size; i++) {
+      ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
+                    static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
+      if (ret != RET_OK) {
+        return ret;
+      }
+      offset += stride;
+    }
+  } else {
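+    // mirror case: input1 is the full tensor and input0 holds the bias row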
+    for (int i = 0; i < batch_size; i++) {
+      ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
+                    static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
+      if (ret != RET_OK) {
+        return ret;
+      }
+      offset += stride;
+    }
+  }
+  return ret;
+}
+
 int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   auto element_num = out_tensors_[0]->ElementsNum();
   int stride = UP_DIV(element_num, context_->thread_num_);
@@ -315,13 +363,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   if (count <= 0) {
     return RET_OK;
   }
-  /* run opt function, every batch one of input is scalar */
-  if (CanBatchScalar()) {
-    return BatchScalarCalc(task_id);
-  }
   int offset = stride * task_id * data_type_len_;
   /* run opt function, one of input is scalar */
-  if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
+  if (isScalarClac()) {  // 2 32 240 240, 1 1 1 1
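+    // the scalar side is passed as-is; the tensor side advances by this task's offset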
     if (param_->in_elements_num0_ == 1) {
       return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
                      static_cast<uint8_t *>(output_ptr_) + offset, count, true);
@@ -330,6 +374,14 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
                      static_cast<uint8_t *>(output_ptr_) + offset, count, true);
     }
   }
+  /* run opt function, every batch one of input is scalar */
+  if (isBatchScalarCalc()) {  // 2 32 240 240, 2 32 1 1
+    return BatchScalarCalc(task_id);
+  }
+  /* each batch is eltwise calculation */
+  if (isBiasCalc()) {  // 2 240 240 32, 1 1 1 32
+    return BiasCalc(task_id);
+  }
   /* need broadcast in runtime */
   if (param_->broadcasting_) {
     stride = UP_DIV(outside_, context_->thread_num_);
@@ -339,7 +391,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
     }
     return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
   }
-  /* no broadcast in runtime */
+  /* all elements eltwise calculation */
   return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
                  static_cast<uint8_t *>(output_ptr_) + offset, count, false);
 }