@@ -15,6 +15,7 @@
  */
 #include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "nnacl/fp16/arithmetic_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
 #include "schema/model_generated.h"
@@ -29,7 +30,6 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Add;
 using mindspore::schema::PrimitiveType_Div;
-using mindspore::schema::PrimitiveType_Eltwise;
 using mindspore::schema::PrimitiveType_Equal;
 using mindspore::schema::PrimitiveType_FloorDiv;
 using mindspore::schema::PrimitiveType_FloorMod;
@@ -47,121 +47,57 @@ using mindspore::schema::PrimitiveType_SquaredDifference;
 using mindspore::schema::PrimitiveType_Sub;

 namespace mindspore::kernel {
-void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
-  if (input0_fp16_ != nullptr) {
-    context_->allocator->Free(input0_fp16_);
-    input0_fp16_ = nullptr;
-  }
-  if (input1_fp16_ != nullptr) {
-    context_->allocator->Free(input1_fp16_);
-    input1_fp16_ = nullptr;
-  }
-  if (output_fp16_ != nullptr) {
-    context_->allocator->Free(output_fp16_);
-    output_fp16_ = nullptr;
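+// Lookup table: (primitive type, activation type) -> element-wise fp16 kernel and the
+// "opt" variant used when one of the two inputs is a single element.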
+ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = {
+  {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
+  {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
+  {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16},
+  {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16},
+  {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
+  {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16},
+  {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16},
+  {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
+  {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16},
+  {PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16},
+  {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
+  {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16},
+  {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16},
+  {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16},
+  {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
+  {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
+  {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16,
+   ElementOptSquaredDifferenceFp16},
+  {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16},
+  {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16},
+  {PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16},
+  {PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16},
+  {PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16},
+  {PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16},
+  {PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16},
+  {PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16,
+   ElementOptGreaterEqualFp16},
+};
+
+ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) {
+  for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16); i++) {
+    if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
+        arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
+      return arithmetic_fun_table_fp16[i].func_;
+    }
   }
+  return nullptr;
 }
-ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {}
+
+ArithmeticOptFuncFp16 GetOptimizedArithmeticFun(int primitive_type, int activation_type) {
+  for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16); i++) {
+    if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
+        arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
+      return arithmetic_fun_table_fp16[i].opt_func_;
+    }
+  }
+  return nullptr;
+}
+
 int ArithmeticFP16CPUKernel::Init() {
-  switch (op_parameter_->type_) {
-    case PrimitiveType_Mul:
-      switch (arithmeticParameter_->activation_type_) {
-        case schema::ActivationType_RELU:
-          arithmetic_run_ = ElementMulReluFp16;
-          break;
-        case schema::ActivationType_RELU6:
-          arithmetic_run_ = ElementMulRelu6Fp16;
-          break;
-        default:
-          arithmetic_run_ = ElementMulFp16;
-          break;
-      }
-      break;
-    case PrimitiveType_Add:
-      switch (arithmeticParameter_->activation_type_) {
-        case schema::ActivationType_RELU:
-          arithmetic_run_ = ElementAddReluFp16;
-          break;
-        case schema::ActivationType_RELU6:
-          arithmetic_run_ = ElementAddRelu6Fp16;
-          break;
-        default:
-          arithmetic_run_ = ElementAddFp16;
-          break;
-      }
-      break;
-    case PrimitiveType_Sub:
-      switch (arithmeticParameter_->activation_type_) {
-        case schema::ActivationType_RELU:
-          arithmetic_run_ = ElementSubReluFp16;
-          break;
-        case schema::ActivationType_RELU6:
-          arithmetic_run_ = ElementSubRelu6Fp16;
-          break;
-        default:
-          arithmetic_run_ = ElementSubFp16;
-          break;
-      }
-      break;
-    case PrimitiveType_Div:
-      switch (arithmeticParameter_->activation_type_) {
-        case schema::ActivationType_RELU:
-          arithmetic_run_ = ElementDivReluFp16;
-          break;
-        case schema::ActivationType_RELU6:
-          arithmetic_run_ = ElementDivRelu6Fp16;
-          break;
-        default:
-          arithmetic_run_ = ElementDivFp16;
-          break;
-      }
-      break;
-    case PrimitiveType_FloorMod:
-      arithmetic_run_ = ElementFloorModFp16;
-      break;
-    case PrimitiveType_FloorDiv:
-      arithmetic_run_ = ElementFloorDivFp16;
-      break;
-    case PrimitiveType_LogicalAnd:
-      arithmetic_run_ = ElementLogicalAndFp16;
-      break;
-    case PrimitiveType_LogicalOr:
-      arithmetic_run_ = ElementLogicalOrFp16;
-      break;
-    case PrimitiveType_SquaredDifference:
-      arithmetic_run_ = ElementSquaredDifferenceFp16;
-      break;
-    case PrimitiveType_Maximum:
-      arithmetic_run_ = ElementMaximumFp16;
-      break;
-    case PrimitiveType_Minimum:
-      arithmetic_run_ = ElementMinimumFp16;
-      break;
-    case PrimitiveType_NotEqual:
-      arithmetic_run_ = ElementNotEqualFp16;
-      break;
-    case PrimitiveType_Equal:
-      arithmetic_run_ = ElementEqualFp16;
-      break;
-    case PrimitiveType_Less:
-      arithmetic_run_ = ElementLessFp16;
-      break;
-    case PrimitiveType_LessEqual:
-      arithmetic_run_ = ElementLessEqual;
-      break;
-    case PrimitiveType_Greater:
-      arithmetic_run_ = ElementGreaterFp16;
-      break;
-    case PrimitiveType_GreaterEqual:
-      arithmetic_run_ = ElementGreaterEqualFp16;
-      break;
-    default:
-      MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
-      arithmetic_run_ = nullptr;
-      break;
-  }
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -169,162 +105,47 @@ int ArithmeticFP16CPUKernel::Init() {
 }

 int ArithmeticFP16CPUKernel::ReSize() {
-  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
-  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
-  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
-    switch (arithmeticParameter_->op_parameter_.type_) {
-      case PrimitiveType_Mul:
-        arithmeticParameter_->broadcasting_ = false;
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_opt_run_ = ElementOptMulReluFp16;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_opt_run_ = ElementOptMulRelu6Fp16;
-            break;
-          default:
-            arithmetic_opt_run_ = ElementOptMulFp16;
-            break;
-        }
-        break;
-      case PrimitiveType_Add:
-        arithmeticParameter_->broadcasting_ = false;
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_opt_run_ = ElementOptAddReluFp16;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_opt_run_ = ElementOptAddRelu6Fp16;
-            break;
-          default:
-            arithmetic_opt_run_ = ElementOptAddFp16;
-            break;
-        }
-        break;
-      case PrimitiveType_Sub:
-        arithmeticParameter_->broadcasting_ = false;
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_opt_run_ = ElementOptSubReluFp16;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_opt_run_ = ElementOptSubRelu6Fp16;
-            break;
-          default:
-            arithmetic_opt_run_ = ElementOptSubFp16;
-            break;
-        }
-        break;
-      case PrimitiveType_Div:
-        arithmeticParameter_->broadcasting_ = false;
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_opt_run_ = ElementOptDivReluFp16;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
-            break;
-          default:
-            arithmetic_opt_run_ = ElementOptDivFp16;
-            break;
-        }
-        break;
-      case PrimitiveType_FloorMod:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptFloorModFp16;
-        break;
-      case PrimitiveType_FloorDiv:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptFloorDivFp16;
-        break;
-      case PrimitiveType_LogicalAnd:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptLogicalAndFp16;
-        break;
-      case PrimitiveType_LogicalOr:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptLogicalOrFp16;
-        break;
-      case PrimitiveType_SquaredDifference:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16;
-        break;
-      case PrimitiveType_Maximum:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptMaximumFp16;
-        break;
-      case PrimitiveType_Minimum:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptMinimumFp16;
-        break;
-      case PrimitiveType_NotEqual:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptNotEqualFp16;
-        break;
-      case PrimitiveType_Equal:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptEqualFp16;
-        break;
-      case PrimitiveType_Less:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptLessFp16;
-        break;
-      case PrimitiveType_LessEqual:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptLessEqualFp16;
-        break;
-      case PrimitiveType_Greater:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptGreaterFp16;
-        break;
-      case PrimitiveType_GreaterEqual:
-        arithmeticParameter_->broadcasting_ = false;
-        arithmetic_opt_run_ = ElementOptGreaterEqualFp16;
-        break;
-      default:
-        break;
-    }
+  if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
+    param_->broadcasting_ = false;
+    arithmetic_opt_func_ = GetOptimizedArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
+  } else {
+    arithmetic_func_ = GetArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
   }
-  if (arithmeticParameter_->broadcasting_) {
+  if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
+    MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
+    return RET_ERROR;
+  }
+  if (param_->broadcasting_) {
     outside_ = 1;
-    for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
-      if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
+    for (int i = param_->ndim_ - 1; i >= 0; --i) {
+      if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
         break_pos_ = i;
         break;
       }
-      outside_ *= arithmeticParameter_->out_shape_[i];
+      outside_ *= param_->out_shape_[i];
     }
-    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
+    ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
+    ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
+    ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
   }
   return RET_OK;
 }
 int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
-                                          int out_count, int out_thread_stride) {
+                                          int out_count, int cur_offset) {
   if (dim > break_pos_) {
-    int error_code =
-      arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
-    if (output_fp16_ != nullptr) {
-      auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
-      int bias = output - output_fp16_;
-      output_fp32 += bias;
-      Float16ToFloat32(output + out_thread_stride, output_fp32 + out_thread_stride, out_count);
-    }
-    return error_code;
+    return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
   }
-  for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
-    int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
-    int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
-    int error_code =
-      BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
-                   input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
-                   output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
-    if (error_code != RET_OK) {
+  for (int i = 0; i < param_->out_shape_[dim]; ++i) {
+    int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i;
+    int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i;
+    int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim],
+                           output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset);
+    if (ret != RET_OK) {
       return RET_ERROR;
     }
   }
@@ -332,62 +153,33 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1,
 }

 int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
-  auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
-  auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
-  auto output = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
-  auto element_num = out_tensors_[0]->ElementsNum();
-  float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
-  float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
-  auto output_data = output_fp16_ == nullptr ? output : output_fp16_;
-  int stride = UP_DIV(element_num, context_->thread_num_);
-  int count = MSMIN(stride, element_num - stride * task_id);
-  auto thread_stride = stride * task_id;
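+  // Each thread handles a contiguous slice of the work: stride_per_thread elements
+  // (or outer "outside_" slices when broadcasting), starting at cur_offset.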
+  int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
+  int cur_offset = stride_per_thread * task_id;
+  int cur_count = MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
-  if (arithmetic_run_ == nullptr) {
-    MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
-    return RET_ERROR;
-  }
-  int error_code = RET_OK;
-  if (arithmeticParameter_->broadcasting_) {
-    stride = UP_DIV(outside_, context_->thread_num_);
-    out_count_ = MSMIN(stride, outside_ - stride * task_id);
-    out_thread_stride_ = stride * task_id;
-    error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
-  } else if (arithmetic_opt_run_ != nullptr) {
-    if (arithmeticParameter_->in_elements_num0_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
-                                       arithmeticParameter_);
-    } else if (arithmeticParameter_->in_elements_num1_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride, count,
-                                       arithmeticParameter_);
-    } else {
-      error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride,
-                                       output_data + thread_stride, count, arithmeticParameter_);
-    }
+  int ret = RET_OK;
+  if (param_->broadcasting_) {
+    ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
+  } else if (param_->in_elements_num0_ == 1) {
+    ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
+  } else if (param_->in_elements_num1_ == 1) {
+    ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
   } else {
-    error_code =
-      arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
+    ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
   }
-  if (error_code != RET_OK) {
-    return RET_ERROR;
-  }
-  if (output_fp16_ != nullptr && !arithmeticParameter_->broadcasting_) {
-    auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
-    Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
   }
-  return RET_OK;
+  return ret;
 }

-static int ArithmeticsRun_Fp16(void *cdata, int task_id) {
+static int ArithmeticsRunFp16(void *cdata, int task_id) {
   auto arithmetic_kernel = reinterpret_cast<ArithmeticFP16CPUKernel *>(cdata);
-  auto error_code = arithmetic_kernel->DoArithmetic(task_id);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
-    return RET_ERROR;
+  auto ret = arithmetic_kernel->DoArithmetic(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]";
   }
-  return RET_OK;
+  return ret;
 }

 int ArithmeticFP16CPUKernel::Run() {
@@ -396,43 +188,45 @@ int ArithmeticFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
     return ret;
   }
-  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
-  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
-  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
-    if (output_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      FreeTmpBuffer();
-      return RET_ERROR;
-    }
+  auto output_tensor = out_tensors_.at(0);
+  is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
+  is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;
+  is_output_fp32_ = output_tensor->data_type() == kNumberTypeFloat32;
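+  // Inputs or output that arrive as fp32 are converted into temporary fp16 buffers here;
+  // tensors that are already fp16 are used directly.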
+  input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
+  input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
+  output_fp16_ = MallocOutputFp16(output_tensor, context_);
+  if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
+    MS_LOG(ERROR) << "Memory allocation failed";
+    FreeTmpBuffer();
+    return RET_ERROR;
   }
-  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
-    if (input0_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      FreeTmpBuffer();
-      return RET_ERROR;
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
-                     arithmeticParameter_->in_elements_num0_);
+  ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRunFp16, this, context_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
   }
-  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
-    if (input1_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      FreeTmpBuffer();
-      return RET_ERROR;
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
-                     arithmeticParameter_->in_elements_num1_);
+  if (is_output_fp32_) {
+    Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
   }
-  ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_);
   FreeTmpBuffer();
   return ret;
 }
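+// Release the temporary fp16 buffers; each one was allocated only when the corresponding tensor is fp32.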
+void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
+  if (is_input0_fp32_) {
+    context_->allocator->Free(input0_fp16_);
+    input0_fp16_ = nullptr;
+  }
+  if (is_input1_fp32_) {
+    context_->allocator->Free(input1_fp16_);
+    input1_fp16_ = nullptr;
+  }
+  if (is_output_fp32_) {
+    context_->allocator->Free(output_fp16_);
+    output_fp16_ = nullptr;
+  }
+}
+
 kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                    const std::vector<lite::tensor::Tensor *> &outputs,
                                                    OpParameter *parameter, const lite::Context *ctx,
@@ -473,5 +267,4 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16Kernel
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
 }  // namespace mindspore::kernel