
!5441 [MS][LITE][Develop]Refactor arithmetic fp16 kernel

Merge pull request !5441 from sunsuodong/fix_arithmetic
tags/v1.0.0
mindspore-ci-bot (Gitee), 5 years ago
commit 85b1dae578
2 changed files with 143 additions and 343 deletions:
  1. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+124, -331)
  2. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h (+19, -12)

mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+124, -331)

@@ -15,6 +15,7 @@
*/


#include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h" #include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h" #include "nnacl/fp16/arithmetic_fp16.h"
#include "nnacl/fp16/cast_fp16.h" #include "nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
@@ -29,7 +30,6 @@ using mindspore::lite::RET_OK;


using mindspore::schema::PrimitiveType_Add;
using mindspore::schema::PrimitiveType_Div;
using mindspore::schema::PrimitiveType_Eltwise;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_FloorDiv;
using mindspore::schema::PrimitiveType_FloorMod;
@@ -47,121 +47,57 @@ using mindspore::schema::PrimitiveType_SquaredDifference;
using mindspore::schema::PrimitiveType_Sub;


namespace mindspore::kernel {
void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
if (input0_fp16_ != nullptr) {
context_->allocator->Free(input0_fp16_);
input0_fp16_ = nullptr;
}
if (input1_fp16_ != nullptr) {
context_->allocator->Free(input1_fp16_);
input1_fp16_ = nullptr;
}
if (output_fp16_ != nullptr) {
context_->allocator->Free(output_fp16_);
output_fp16_ = nullptr;
ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = {
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
{PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
{PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16},
{PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16},
{PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
{PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
{PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16},
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16},
{PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
{PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16,
ElementOptSquaredDifferenceFp16},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16},
{PrimitiveType_NotEqual, schema::ActivationType_NO_ACTIVATION, ElementNotEqualFp16, ElementOptNotEqualFp16},
{PrimitiveType_Equal, schema::ActivationType_NO_ACTIVATION, ElementEqualFp16, ElementOptEqualFp16},
{PrimitiveType_Less, schema::ActivationType_NO_ACTIVATION, ElementLessFp16, ElementOptLessFp16},
{PrimitiveType_LessEqual, schema::ActivationType_NO_ACTIVATION, ElementLessEqual, ElementOptLessEqualFp16},
{PrimitiveType_Greater, schema::ActivationType_NO_ACTIVATION, ElementGreaterFp16, ElementOptGreaterFp16},
{PrimitiveType_GreaterEqual, schema::ActivationType_NO_ACTIVATION, ElementGreaterEqualFp16,
ElementOptGreaterEqualFp16},
};

ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) {
for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) {
if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
return arithmetic_fun_table_fp16[i].func_;
}
}
return nullptr;
}


ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {}
ArithmeticOptFuncFp16 GetOptimizedArithmeticFun(int primitive_type, int activation_type) {
for (size_t i = 0; i < sizeof(arithmetic_fun_table_fp16); i++) {
if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
return arithmetic_fun_table_fp16[i].opt_func_;
}
}
return nullptr;
}


int ArithmeticFP16CPUKernel::Init() {
switch (op_parameter_->type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementMulReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementMulRelu6Fp16;
break;
default:
arithmetic_run_ = ElementMulFp16;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementAddReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementAddRelu6Fp16;
break;
default:
arithmetic_run_ = ElementAddFp16;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementSubReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementSubRelu6Fp16;
break;
default:
arithmetic_run_ = ElementSubFp16;
break;
}
break;
case PrimitiveType_Div:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6Fp16;
break;
default:
arithmetic_run_ = ElementDivFp16;
break;
}
break;
case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorModFp16;
break;
case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDivFp16;
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAndFp16;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOrFp16;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifferenceFp16;
break;
case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximumFp16;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimumFp16;
break;
case PrimitiveType_NotEqual:
arithmetic_run_ = ElementNotEqualFp16;
break;
case PrimitiveType_Equal:
arithmetic_run_ = ElementEqualFp16;
break;
case PrimitiveType_Less:
arithmetic_run_ = ElementLessFp16;
break;
case PrimitiveType_LessEqual:
arithmetic_run_ = ElementLessEqual;
break;
case PrimitiveType_Greater:
arithmetic_run_ = ElementGreaterFp16;
break;
case PrimitiveType_GreaterEqual:
arithmetic_run_ = ElementGreaterEqualFp16;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr;
break;
}
if (!InferShapeDone()) {
return RET_OK;
}
@@ -169,162 +105,47 @@ int ArithmeticFP16CPUKernel::Init() {
}


int ArithmeticFP16CPUKernel::ReSize() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
param_->out_elements_num_ = out_tensors_[0]->ElementsNum();


if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptMulReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptMulRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptMulFp16;
break;
}
break;
case PrimitiveType_Add:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptAddReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptAddRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptAddFp16;
break;
}
break;
case PrimitiveType_Sub:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptSubReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptSubRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptSubFp16;
break;
}
break;
case PrimitiveType_Div:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptDivReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptDivFp16;
break;
}
break;
case PrimitiveType_FloorMod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorModFp16;
break;
case PrimitiveType_FloorDiv:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorDivFp16;
break;
case PrimitiveType_LogicalAnd:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalAndFp16;
break;
case PrimitiveType_LogicalOr:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalOrFp16;
break;
case PrimitiveType_SquaredDifference:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16;
break;
case PrimitiveType_Maximum:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMaximumFp16;
break;
case PrimitiveType_Minimum:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMinimumFp16;
break;
case PrimitiveType_NotEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptNotEqualFp16;
break;
case PrimitiveType_Equal:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptEqualFp16;
break;
case PrimitiveType_Less:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessFp16;
break;
case PrimitiveType_LessEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessEqualFp16;
break;
case PrimitiveType_Greater:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterFp16;
break;
case PrimitiveType_GreaterEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterEqualFp16;
break;
default:
break;
}
if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
param_->broadcasting_ = false;
arithmetic_opt_func_ = GetOptimizedArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
} else {
arithmetic_func_ = GetArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
}

if (arithmeticParameter_->broadcasting_) {
if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
return RET_ERROR;
}
if (param_->broadcasting_) {
outside_ = 1;
for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
for (int i = param_->ndim_ - 1; i >= 0; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= arithmeticParameter_->out_shape_[i];
outside_ *= param_->out_shape_[i];
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
}
return RET_OK;
}


int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
int out_count, int out_thread_stride) {
int out_count, int cur_offset) {
if (dim > break_pos_) {
int error_code =
arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
if (output_fp16_ != nullptr) {
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
int bias = output - output_fp16_;
output_fp32 += bias;
Float16ToFloat32(output + out_thread_stride, output_fp32 + out_thread_stride, out_count);
}
return error_code;
return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code =
BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
if (error_code != RET_OK) {
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i;
int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim],
output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset);
if (ret != RET_OK) {
return RET_ERROR;
}
}
@@ -332,62 +153,33 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1,
}


int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
auto output = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
auto element_num = out_tensors_[0]->ElementsNum();

float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
auto output_data = output_fp16_ == nullptr ? output : output_fp16_;
int stride = UP_DIV(element_num, context_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id);
auto thread_stride = stride * task_id;
int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
int cur_offset = stride_per_thread * task_id;
int cur_count = MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);


if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
return RET_ERROR;
}

int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) {
stride = UP_DIV(outside_, context_->thread_num_);
out_count_ = MSMIN(stride, outside_ - stride * task_id);
out_thread_stride_ = stride * task_id;
error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_);
} else if (arithmetic_opt_run_ != nullptr) {
if (arithmeticParameter_->in_elements_num0_ == 1) {
error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
arithmeticParameter_);
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride, count,
arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride,
output_data + thread_stride, count, arithmeticParameter_);
}
int ret = RET_OK;
if (param_->broadcasting_) {
ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
} else if (param_->in_elements_num0_ == 1) {
ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
} else if (param_->in_elements_num1_ == 1) {
ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
} else {
error_code =
arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
}
if (error_code != RET_OK) {
return RET_ERROR;
}
if (output_fp16_ != nullptr && !arithmeticParameter_->broadcasting_) {
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
}
return RET_OK;
return ret;
}


static int ArithmeticsRun_Fp16(void *cdata, int task_id) {
static int ArithmeticsRunFp16(void *cdata, int task_id) {
auto arithmetic_kernel = reinterpret_cast<ArithmeticFP16CPUKernel *>(cdata);
auto error_code = arithmetic_kernel->DoArithmetic(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
auto ret = arithmetic_kernel->DoArithmetic(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]";
}
return RET_OK;
return ret;
}


int ArithmeticFP16CPUKernel::Run() {
@@ -396,43 +188,45 @@ int ArithmeticFP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret; MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret; return ret;
} }
auto output_tensor = out_tensors_.at(0);
is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;
is_output_fp32_ = output_tensor->data_type() == kNumberTypeFloat32;


arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
output_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
if (output_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer();
return RET_ERROR;
}
input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
output_fp16_ = MallocOutputFp16(output_tensor, context_);
if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
return RET_ERROR;
}
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
input0_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer();
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRunFp16, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
input1_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input1_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
FreeTmpBuffer();
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
if (is_output_fp32_) {
Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(output_tensor->Data()), output_tensor->ElementsNum());
}
ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_);
FreeTmpBuffer();
return ret;
}


void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
if (is_input0_fp32_) {
context_->allocator->Free(input0_fp16_);
input0_fp16_ = nullptr;
}
if (is_input1_fp32_) {
context_->allocator->Free(input1_fp16_);
input1_fp16_ = nullptr;
}
if (is_output_fp32_) {
context_->allocator->Free(output_fp16_);
output_fp16_ = nullptr;
}
}

kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *parameter, const lite::Context *ctx,
@@ -473,5 +267,4 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16Kernel
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
}  // namespace mindspore::kernel
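
What the .cc changes amount to: the nested switch statements that previously selected arithmetic_run_ in Init() and arithmetic_opt_run_ in ReSize() are replaced by one static table (arithmetic_fun_table_fp16) plus two lookup helpers (GetArithmeticFun, GetOptimizedArithmeticFun). Below is a minimal, self-contained sketch of that table-driven dispatch pattern. It is an illustration only: FuncInfo, GetFunc, ElementAdd and ElementMul are made-up names, and float stands in for float16_t so the sketch compiles anywhere.

#include <cstddef>
#include <iostream>

using ElementFunc = int (*)(const float *in0, const float *in1, float *out, int size);

static int ElementAdd(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] + in1[i];
  return 0;
}

static int ElementMul(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] * in1[i];
  return 0;
}

struct FuncInfo {
  int primitive_type;
  int activation_type;
  ElementFunc func;
};

enum { kPrimAdd = 0, kPrimMul = 1 };
enum { kActNone = 0 };

static const FuncInfo kFuncTable[] = {
  {kPrimAdd, kActNone, ElementAdd},
  {kPrimMul, kActNone, ElementMul},
};

// Look up the element-wise kernel for a (primitive, activation) pair.
// The element count is table bytes divided by entry bytes, so the scan stays in bounds.
static ElementFunc GetFunc(int primitive_type, int activation_type) {
  for (size_t i = 0; i < sizeof(kFuncTable) / sizeof(kFuncTable[0]); ++i) {
    if (kFuncTable[i].primitive_type == primitive_type &&
        kFuncTable[i].activation_type == activation_type) {
      return kFuncTable[i].func;
    }
  }
  return nullptr;  // unsupported combination: caller must check
}

int main() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {5, 6, 7, 8};
  float out[4] = {0};
  ElementFunc f = GetFunc(kPrimAdd, kActNone);
  if (f == nullptr) return 1;
  f(a, b, out, 4);
  for (float v : out) std::cout << v << ' ';  // prints: 6 8 10 12
  std::cout << '\n';
  return 0;
}

In the patch itself each table entry carries both the plain and the single-scalar-operand ("Opt") function pointer, which is what lets ReSize() collapse to a single lookup on either path.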

mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h (+19, -12)

@@ -23,39 +23,46 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"


namespace mindspore::kernel {
class ArithmeticFP16CPUKernel : public LiteKernel {
typedef int (*ArithmeticRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticOptRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
typedef struct {
int primitive_type_;
int activation_type_;
ArithmeticFuncFp16 func_;
ArithmeticOptFuncFp16 opt_func_;
} ARITHMETIC_FUNC_INFO_FP16;


class ArithmeticFP16CPUKernel : public LiteKernel {
public:
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~ArithmeticFP16CPUKernel() override;
~ArithmeticFP16CPUKernel() = default;


int Init() override;
int ReSize() override;
int Run() override;
int DoArithmetic(int task_id);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride);
int out_thread_stride);


private:
void FreeTmpBuffer();
int outside_;
int break_pos_;
int out_thread_stride_;
int out_count_;
bool is_input0_fp32_ = false;
bool is_input1_fp32_ = false;
bool is_output_fp32_ = false;
float16_t *input0_fp16_ = nullptr;
float16_t *input1_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
ArithmeticParameter *arithmeticParameter_ = nullptr;
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticParameter *param_ = nullptr;
ArithmeticFuncFp16 arithmetic_func_ = nullptr;
ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_
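
The broadcast path that survives this refactor (ComputeStrides in ReSize(), the recursive BroadcastRun(), and the break_pos_/outside_ bookkeeping) can also be read in isolation. The sketch below is a single-threaded simplification under assumed names (BroadcastCtx, tail_count and ElementAdd are illustrative, and float stands in for float16_t): dimensions are walked from the outermost in, a size-1 input dimension keeps offset 0, and once the remaining dimensions all match (dim > break_pos) a flat element-wise function handles the contiguous tail.

#include <iostream>
#include <vector>

// Flat element-wise op applied to the contiguous tail.
static int ElementAdd(const float *a, const float *b, float *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = a[i] + b[i];
  return 0;
}

// Row-major strides for a shape.
static std::vector<int> ComputeStrides(const std::vector<int> &shape) {
  std::vector<int> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}

struct BroadcastCtx {
  std::vector<int> out_shape, in_shape0, in_shape1;
  std::vector<int> out_strides, in_strides0, in_strides1;
  int break_pos;   // outermost dim (scanning from the back) where the input shapes differ
  int tail_count;  // contiguous elements handled per leaf call
};

static int BroadcastRun(const BroadcastCtx &c, const float *in0, const float *in1, float *out, int dim) {
  if (dim > c.break_pos) {
    return ElementAdd(in0, in1, out, c.tail_count);
  }
  for (int i = 0; i < c.out_shape[dim]; ++i) {
    int pos0 = c.in_shape0[dim] == 1 ? 0 : i;  // size-1 dims broadcast: stay at offset 0
    int pos1 = c.in_shape1[dim] == 1 ? 0 : i;
    int ret = BroadcastRun(c, in0 + pos0 * c.in_strides0[dim], in1 + pos1 * c.in_strides1[dim],
                           out + i * c.out_strides[dim], dim + 1);
    if (ret != 0) return ret;
  }
  return 0;
}

int main() {
  BroadcastCtx c;
  c.in_shape0 = {2, 1, 3};
  c.in_shape1 = {2, 4, 3};
  c.out_shape = {2, 4, 3};
  c.in_strides0 = ComputeStrides(c.in_shape0);
  c.in_strides1 = ComputeStrides(c.in_shape1);
  c.out_strides = ComputeStrides(c.out_shape);
  // Scan from the innermost dim outward until the input shapes differ.
  c.break_pos = 0;
  c.tail_count = 1;
  for (int i = static_cast<int>(c.out_shape.size()) - 1; i >= 0; --i) {
    if (c.in_shape0[i] != c.in_shape1[i]) { c.break_pos = i; break; }
    c.tail_count *= c.out_shape[i];
  }
  std::vector<float> a(6), b(24, 1.0f), out(24, 0.0f);
  for (int i = 0; i < 6; ++i) a[i] = static_cast<float>(i);
  int ret = BroadcastRun(c, a.data(), b.data(), out.data(), 0);
  std::cout << "ret=" << ret << " out[0]=" << out[0] << " out[3]=" << out[3] << '\n';  // ret=0 out[0]=1 out[3]=1
  return 0;
}

In the patch itself the outer work is additionally split across threads: DoArithmetic() derives stride_per_thread from outside_ when broadcasting_ is set and from out_elements_num_ otherwise, and each task runs this walk on its own slice.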
