|
|
|
@@ -96,202 +96,59 @@ int ArithmeticCPUKernel::InitBroadCastCase() { |
|
|
|
} |
|
|
|
|
|
|
|
void ArithmeticCPUKernel::InitRunFunction() { |
|
|
|
switch (op_parameter_->type_) { |
|
|
|
case PrimitiveType_Mul: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmetic_run_ = ElementMulRelu; |
|
|
|
arithmetic_run_int_ = ElementMulReluInt; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmetic_run_ = ElementMulRelu6; |
|
|
|
arithmetic_run_int_ = ElementMulRelu6Int; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmetic_run_ = ElementMul; |
|
|
|
arithmetic_run_int_ = ElementMulInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Add: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmetic_run_ = ElementAddRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmetic_run_ = ElementAddRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmetic_run_ = ElementAdd; |
|
|
|
arithmetic_run_int_ = ElementAddInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Sub: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmetic_run_ = ElementSubRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmetic_run_ = ElementSubRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmetic_run_ = ElementSub; |
|
|
|
arithmetic_run_int_ = ElementSubInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Div: |
|
|
|
case PrimitiveType_RealDiv: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmetic_run_ = ElementDivRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmetic_run_ = ElementDivRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmetic_run_ = ElementDiv; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_LogicalAnd: |
|
|
|
arithmetic_run_ = ElementLogicalAnd; |
|
|
|
arithmetic_run_int_ = ElementLogicalAndInt; |
|
|
|
arithmetic_run_bool_ = ElementLogicalAndBool; |
|
|
|
break; |
|
|
|
case PrimitiveType_LogicalOr: |
|
|
|
arithmetic_run_ = ElementLogicalOr; |
|
|
|
break; |
|
|
|
case PrimitiveType_Maximum: |
|
|
|
arithmetic_run_ = ElementMaximum; |
|
|
|
arithmetic_run_int_ = ElementMaximumInt; |
|
|
|
break; |
|
|
|
case PrimitiveType_Minimum: |
|
|
|
arithmetic_run_ = ElementMinimum; |
|
|
|
arithmetic_run_int_ = ElementMinimumInt; |
|
|
|
break; |
|
|
|
case PrimitiveType_FloorDiv: |
|
|
|
arithmetic_run_ = ElementFloorDiv; |
|
|
|
arithmetic_run_int_ = ElementFloorDivInt; |
|
|
|
break; |
|
|
|
case PrimitiveType_FloorMod: |
|
|
|
arithmetic_run_ = ElementFloorMod; |
|
|
|
arithmetic_run_int_ = ElementFloorModInt; |
|
|
|
break; |
|
|
|
case PrimitiveType_Mod: |
|
|
|
arithmetic_run_ = ElementMod; |
|
|
|
arithmetic_run_int_ = ElementModInt; |
|
|
|
break; |
|
|
|
case PrimitiveType_SquaredDifference: |
|
|
|
arithmetic_run_ = ElementSquaredDifference; |
|
|
|
break; |
|
|
|
case PrimitiveType_Equal: |
|
|
|
case PrimitiveType_Less: |
|
|
|
case PrimitiveType_Greater: |
|
|
|
case PrimitiveType_NotEqual: |
|
|
|
case PrimitiveType_LessEqual: |
|
|
|
case PrimitiveType_GreaterEqual: |
|
|
|
arithmetic_run_ = nullptr; |
|
|
|
arithmetic_run_int_ = nullptr; |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; |
|
|
|
arithmetic_run_ = nullptr; |
|
|
|
break; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
// Kernel lookup table: one row per (primitive type, activation type) pair.
//
// Rows are matched on primitive_type_ and activation_type_; the remaining
// members read by the lookup code are func_, int_func_, bool_func_, opt_func_
// and opt_int_func_. The column order used below is presumably the member
// order of ARITHMETIC_FUNC_INFO_FP32 -- confirm against the struct declaration.
// A nullptr cell means no kernel is registered for that slot.
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
  // Mul
  {PrimitiveType_Mul, schema::ActivationType_RELU,
   ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu, ElementOptMulReluInt},
  {PrimitiveType_Mul, schema::ActivationType_RELU6,
   ElementMulRelu6, ElementMulRelu6Int, nullptr, ElementOptMulRelu6, ElementOptMulRelu6Int},
  {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION,
   ElementMul, ElementMulInt, nullptr, ElementOptMul, ElementOptMulInt},
  // Add
  {PrimitiveType_Add, schema::ActivationType_RELU,
   ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr},
  {PrimitiveType_Add, schema::ActivationType_RELU6,
   ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, nullptr},
  {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION,
   ElementAdd, ElementAddInt, nullptr, ElementOptAdd, ElementOptAddInt},
  // Sub
  {PrimitiveType_Sub, schema::ActivationType_RELU,
   ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr},
  {PrimitiveType_Sub, schema::ActivationType_RELU6,
   ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, nullptr},
  {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION,
   ElementSub, ElementSubInt, nullptr, ElementOptSub, ElementOptSubInt},
  // Div
  {PrimitiveType_Div, schema::ActivationType_RELU,
   ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
  {PrimitiveType_Div, schema::ActivationType_RELU6,
   ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr},
  {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION,
   ElementDiv, nullptr, nullptr, ElementOptDiv, ElementOptDivInt},
  // RealDiv (same kernels as Div)
  {PrimitiveType_RealDiv, schema::ActivationType_RELU,
   ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
  {PrimitiveType_RealDiv, schema::ActivationType_RELU6,
   ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr},
  {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION,
   ElementDiv, nullptr, nullptr, ElementOptDiv, ElementOptDivInt},
  // Logical operators
  {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION,
   ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool, nullptr, nullptr},
  {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION,
   ElementLogicalOr, nullptr, nullptr, nullptr, nullptr},
  // Min / max
  {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION,
   ElementMaximum, ElementMaximumInt, nullptr, nullptr, nullptr},
  {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION,
   ElementMinimum, ElementMinimumInt, nullptr, nullptr, nullptr},
  // Floor / modulo family
  {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION,
   ElementFloorMod, ElementFloorModInt, nullptr, nullptr, nullptr},
  {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION,
   ElementFloorDiv, ElementFloorDivInt, nullptr, nullptr, nullptr},
  {PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION,
   ElementMod, ElementModInt, nullptr, ElementOptMod, ElementOptModInt},
  // SquaredDifference
  {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION,
   ElementSquaredDifference, nullptr, nullptr, nullptr, nullptr},
};
|
|
|
|
|
|
|
void ArithmeticCPUKernel::InitOptRunFunction() { |
|
|
|
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { |
|
|
|
switch (arithmeticParameter_->op_parameter_.type_) { |
|
|
|
case PrimitiveType_Mul: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptMulRelu; |
|
|
|
arithmetic_opt_run_int_ = ElementOptMulReluInt; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptMulRelu6; |
|
|
|
arithmetic_opt_run_int_ = ElementOptMulRelu6Int; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptMul; |
|
|
|
arithmetic_opt_run_int_ = ElementOptMulInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Add: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptAddRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptAddRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptAdd; |
|
|
|
arithmetic_opt_run_int_ = ElementOptAddInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Sub: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptSubRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptSubRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptSub; |
|
|
|
arithmetic_opt_run_int_ = ElementOptSubInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Div: |
|
|
|
case PrimitiveType_RealDiv: |
|
|
|
switch (arithmeticParameter_->activation_type_) { |
|
|
|
case schema::ActivationType_RELU: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptDivRelu; |
|
|
|
break; |
|
|
|
case schema::ActivationType_RELU6: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptDivRelu6; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptDiv; |
|
|
|
arithmetic_opt_run_int_ = ElementOptDivInt; |
|
|
|
break; |
|
|
|
} |
|
|
|
break; |
|
|
|
case PrimitiveType_Mod: |
|
|
|
arithmeticParameter_->broadcasting_ = false; |
|
|
|
arithmetic_opt_run_ = ElementOptMod; |
|
|
|
arithmetic_opt_run_int_ = ElementOptModInt; |
|
|
|
break; |
|
|
|
default: |
|
|
|
arithmetic_opt_run_ = nullptr; |
|
|
|
arithmetic_opt_run_int_ = nullptr; |
|
|
|
break; |
|
|
|
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32); |
|
|
|
for (size_t i = 0; i < length; i++) { |
|
|
|
if (fun_table[i].primitive_type_ == op_parameter_->type_ && |
|
|
|
fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) { |
|
|
|
arithmetic_run_ = fun_table[i].func_; |
|
|
|
arithmetic_run_int_ = fun_table[i].int_func_; |
|
|
|
arithmetic_run_bool_ = fun_table[i].bool_func_; |
|
|
|
arithmetic_opt_run_ = fun_table[i].opt_func_; |
|
|
|
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_; |
|
|
|
return; |
|
|
|
} |
|
|
|
} else { |
|
|
|
arithmetic_opt_run_ = nullptr; |
|
|
|
arithmetic_opt_run_int_ = nullptr; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void ArithmeticCPUKernel::InitParam() { |
|
|
|
@@ -321,7 +178,6 @@ void ArithmeticCPUKernel::InitParam() { |
|
|
|
|
|
|
|
int ArithmeticCPUKernel::ReSize() { |
|
|
|
InitParam(); |
|
|
|
InitOptRunFunction(); |
|
|
|
return InitBroadCastCase(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -359,6 +215,66 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 |
|
|
|
if (input0_broadcast_ == true || input1_broadcast_ == true) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ || |
|
|
|
arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t break_axis = 0; |
|
|
|
for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) { |
|
|
|
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { |
|
|
|
break_axis = i; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (break_axis < arithmeticParameter_->ndim_) { |
|
|
|
for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) { |
|
|
|
if (arithmeticParameter_->in_shape1_[i] != 1) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
break_pos_ = break_axis; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { |
|
|
|
int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1]; |
|
|
|
int batch_per_thread = UP_DIV(batch, thread_count_); |
|
|
|
|
|
|
|
int start_batch = batch_per_thread * task_id; |
|
|
|
int end_batch = MSMIN(start_batch + batch_per_thread, batch); |
|
|
|
int batch_size = end_batch - start_batch; |
|
|
|
|
|
|
|
int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1]; |
|
|
|
int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1]; |
|
|
|
int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1]; |
|
|
|
|
|
|
|
int offset0 = stride0 * start_batch; |
|
|
|
int offset1 = stride1 * start_batch; |
|
|
|
int out_offset = out_stride * start_batch; |
|
|
|
|
|
|
|
int ret = RET_OK; |
|
|
|
for (int i = 0; i < batch_size; i++) { |
|
|
|
if (data_type_ == kDataTypeFloat) { |
|
|
|
ret = arithmetic_opt_run_( |
|
|
|
reinterpret_cast<float *>(input0_ptr_) + offset0, reinterpret_cast<float *>(input1_ptr_) + offset1, |
|
|
|
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); |
|
|
|
} else { |
|
|
|
ret = arithmetic_opt_run_int_( |
|
|
|
reinterpret_cast<int *>(input0_ptr_) + offset0, reinterpret_cast<int *>(input1_ptr_) + offset1, |
|
|
|
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); |
|
|
|
} |
|
|
|
offset0 += stride0; |
|
|
|
offset1 += stride1; |
|
|
|
out_offset += out_stride; |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticCPUKernel::DoArithmetic(int task_id) { |
|
|
|
auto element_num = out_tensors_[0]->ElementsNum(); |
|
|
|
|
|
|
|
@@ -370,27 +286,12 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { |
|
|
|
MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
int error_code; |
|
|
|
if (arithmeticParameter_->broadcasting_) { |
|
|
|
/* need broadcast in runtime */ |
|
|
|
stride = UP_DIV(outside_, thread_count_); |
|
|
|
int out_count = MSMIN(stride, outside_ - stride * task_id); |
|
|
|
if (out_count <= 0) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
int out_thread_stride = stride * task_id; |
|
|
|
if (data_type_ == kDataTypeFloat) { |
|
|
|
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_), |
|
|
|
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); |
|
|
|
} else { |
|
|
|
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_), |
|
|
|
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); |
|
|
|
} |
|
|
|
return error_code; |
|
|
|
if (CanBatchScalar()) { |
|
|
|
return BatchScalarCalc(task_id); |
|
|
|
} |
|
|
|
|
|
|
|
if (arithmetic_opt_run_ != nullptr) { |
|
|
|
int error_code = RET_OK; |
|
|
|
if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) && |
|
|
|
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) { |
|
|
|
/* run opt function |
|
|
|
* one of input is scalar */ |
|
|
|
if (arithmeticParameter_->in_elements_num0_ == 1) { |
|
|
|
@@ -413,11 +314,24 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { |
|
|
|
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_), |
|
|
|
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); |
|
|
|
} |
|
|
|
} |
|
|
|
return error_code; |
|
|
|
} |
|
|
|
if (arithmeticParameter_->broadcasting_) { |
|
|
|
/* need broadcast in runtime */ |
|
|
|
stride = UP_DIV(outside_, thread_count_); |
|
|
|
int out_count = MSMIN(stride, outside_ - stride * task_id); |
|
|
|
if (out_count <= 0) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
int out_thread_stride = stride * task_id; |
|
|
|
if (data_type_ == kDataTypeFloat) { |
|
|
|
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_), |
|
|
|
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar"; |
|
|
|
return RET_ERROR; |
|
|
|
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_), |
|
|
|
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); |
|
|
|
} |
|
|
|
|
|
|
|
return error_code; |
|
|
|
} |
|
|
|
|
|
|
|
|