
modify arm cpu fp16&fp32 op: Arithmetic

tags/v0.7.0-beta
tao_yunhao 5 years ago
parent commit c31ad9d5dd
3 changed files with 87 additions and 43 deletions:
  1. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+58 −43)
  2. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c (+27 −0)
  3. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h (+2 −0)

mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+58 −43)

@@ -162,34 +162,9 @@ int ArithmeticFP16CPUKernel::Init() {
 }
 
 int ArithmeticFP16CPUKernel::ReSize() {
+  FreeTmpBuffer();
   arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
-    if (input0_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
-  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
-    if (input0_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
-  if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
-    if (output_fp16_ == nullptr) {
-      MS_LOG(ERROR) << "malloc data fail!";
-      return RET_ERROR;
-    }
-  }
 
   if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
     switch (arithmeticParameter_->op_parameter_.type_) {
@@ -292,20 +267,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
         break;
     }
   }
 
-  if (arithmeticParameter_->broadcasting_) {
-    outside_ = 1;
-    for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
-      if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
-        break_pos_ = i;
-        break;
-      }
-      outside_ *= arithmeticParameter_->out_shape_[i];
-    }
-    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
-  }
   return RET_OK;
 }


@@ -344,10 +305,8 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
 
   int error_code = RET_OK;
   if (arithmeticParameter_->broadcasting_) {
-    stride = UP_DIV(outside_, context_->thread_num_);
-    out_count_ = MSMIN(stride, outside_ - stride * task_id);
-    out_thread_stride_ = stride * task_id;
-    error_code = broadcast_run_(input0_data, input1_data1, output_data, 0);
+    error_code =
+      arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
   } else if (arithmetic_opt_run_ != nullptr) {
     if (arithmeticParameter_->in_elements_num0_ == 1) {
       error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
@@ -364,6 +323,7 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
       arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
   }
   if (error_code != RET_OK) {
+    FreeTmpBuffer();
     return RET_ERROR;
   }
   if (output_fp16_ != nullptr) {
@@ -390,6 +350,37 @@ int ArithmeticFP16CPUKernel::Run() {
     return ret;
   }
 
+  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
+    input0_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    if (input0_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
+    input1_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
+    if (input1_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+  if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
+    output_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    if (output_fp16_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+  }
+
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
     Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
                      arithmeticParameter_->in_elements_num0_);
@@ -399,9 +390,33 @@ int ArithmeticFP16CPUKernel::Run() {
                      arithmeticParameter_->in_elements_num1_);
   }
 
+  if (arithmeticParameter_->broadcasting_) {
+    auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t);
+    tile_data0_ = reinterpret_cast<float16_t *>(malloc(tile_size));
+    if (tile_data0_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+    tile_data1_ = reinterpret_cast<float16_t *>(malloc(tile_size));
+    if (tile_data1_ == nullptr) {
+      MS_LOG(ERROR) << "malloc data fail!";
+      FreeTmpBuffer();
+      return RET_ERROR;
+    }
+    auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
+    auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
+
+    float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
+    float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
+
+    TileDimensionsFp16(input0_data, input1_data1, tile_data0_, tile_data1_, arithmeticParameter_);
+  }
+
   ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret;
+    FreeTmpBuffer();
     return ret;
   }
   return RET_OK;
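With this change, the broadcast path in DoArithmetic() no longer walks strides per thread: Run() pre-tiles both inputs to the full output shape, so each task just processes a flat slice [thread_stride, thread_stride + count). The diff does not show how thread_stride and count are derived; the sketch below reconstructs the usual UP_DIV/MSMIN split hinted at by the removed lines. The variable names and hard-coded sizes are assumptions for illustration, not the exact kernel code.

#include <cstdio>

// Assumed equivalents of the nnacl helper macros referenced in the removed lines.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

int main() {
  int out_elements_num = 10;  // total elements after tiling (hypothetical)
  int thread_num = 4;
  for (int task_id = 0; task_id < thread_num; task_id++) {
    int stride = UP_DIV(out_elements_num, thread_num);  // ceil(10 / 4) = 3
    int count = MSMIN(stride, out_elements_num - stride * task_id);
    int thread_stride = stride * task_id;
    if (count <= 0) continue;  // trailing tasks may get no work
    // each task would then run something like:
    // arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride,
    //                 output_data + thread_stride, count);
    printf("task %d: offset %d, count %d\n", task_id, thread_stride, count);
  }
  return 0;
}

For 10 elements on 4 threads this prints offsets 0/3/6/9 with counts 3/3/3/1, which is why the last task needs the MSMIN clamp.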


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c (+27 −0)

@@ -18,6 +18,33 @@
 #include <math.h>
 #include "nnacl/arithmetic_common.h"
 
+void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
+                          int *outStrides, int *multiple) {
+  int srcDimSize = inShape[dim];
+  if (dim == ndim - 1) {
+    for (int i = 0; i < multiple[dim]; i++) {
+      memcpy(outData, inData, srcDimSize * sizeof(float16_t));
+      outData += srcDimSize;
+    }
+    return;
+  }
+  for (size_t i = 0; i < srcDimSize; i++) {
+    for (size_t j = 0; j < multiple[dim]; j++) {
+      TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim,
+                           inShape, inStrides, outStrides, multiple);
+    }
+  }
+}
+
+void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+                        ArithmeticParameter *param) {
+  CalcMultiplesAndStrides(param);
+  TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
+                       param->multiples0_);
+  TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_,
+                       param->multiples1_);
+}
+
 int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
   int block_mod = element_size % C8NUM;
   int block_c8 = element_size - block_mod;
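TileOneDimensionFp16 realizes broadcasting by recursing over dimensions: at the innermost dimension it memcpys the contiguous run multiple[dim] times, and at outer dimensions it recurses once per source index i and repeat j, placing each copy at offset outStrides[dim] * (i + j * srcDimSize). A minimal host-side sketch of the same recursion, using float instead of float16_t so it builds with any C++ compiler, broadcasting a [2, 1] input to [2, 3]:

#include <cstdio>
#include <cstring>

// Same recursion as the diff's TileOneDimensionFp16, retyped for plain float.
void TileOneDimension(const float *in, float *out, int dim, int ndim,
                      const int *inShape, const int *inStrides,
                      const int *outStrides, const int *multiple) {
  int srcDimSize = inShape[dim];
  if (dim == ndim - 1) {  // innermost dim: repeat the contiguous run
    for (int i = 0; i < multiple[dim]; i++) {
      memcpy(out, in, srcDimSize * sizeof(float));
      out += srcDimSize;
    }
    return;
  }
  for (int i = 0; i < srcDimSize; i++) {
    for (int j = 0; j < multiple[dim]; j++) {
      TileOneDimension(in + inStrides[dim] * i,
                       out + outStrides[dim] * (i + j * srcDimSize),
                       dim + 1, ndim, inShape, inStrides, outStrides, multiple);
    }
  }
}

int main() {
  const float in[2] = {1.0f, 2.0f};  // shape [2, 1]
  float out[6] = {0};                // shape [2, 3]
  const int inShape[2] = {2, 1};
  const int inStrides[2] = {1, 1};   // row-major strides of [2, 1]
  const int outStrides[2] = {3, 1};  // row-major strides of [2, 3]
  const int multiple[2] = {1, 3};    // out_shape / in_shape per dim
  TileOneDimension(in, out, 0, 2, inShape, inStrides, outStrides, multiple);
  for (float v : out) printf("%g ", v);  // expected: 1 1 1 2 2 2
  printf("\n");
  return 0;
}

The expected output, 1 1 1 2 2 2, shows each source element repeated three times along the broadcast last dimension.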


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h (+2 −0)

@@ -111,6 +111,8 @@ int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, in
 int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 
+void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+                        ArithmeticParameter *param);
 #ifdef __cplusplus
 }
 #endif
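The header only exports TileDimensionsFp16; the multiples and strides it feeds to the recursion come from CalcMultiplesAndStrides() in nnacl/arithmetic_common.h, which this diff does not show. The sketch below reproduces what that setup step plausibly computes for a pair of broadcastable shapes, row-major strides plus a per-dimension repeat factor out_shape[i] / in_shape[i]; treat it as an assumption about its behavior, not the library source.

#include <cstdio>

int main() {
  const int ndim = 3;
  int in_shape0[3] = {2, 1, 3};
  int in_shape1[3] = {1, 4, 3};
  int out_shape[3] = {2, 4, 3};  // elementwise max of the input shapes
  int strides0[3], strides1[3], out_strides[3], mul0[3], mul1[3];
  // row-major strides: stride[i] = product of the dims to the right
  for (int i = ndim - 1, s0 = 1, s1 = 1, so = 1; i >= 0; --i) {
    strides0[i] = s0; strides1[i] = s1; out_strides[i] = so;
    s0 *= in_shape0[i]; s1 *= in_shape1[i]; so *= out_shape[i];
    mul0[i] = out_shape[i] / in_shape0[i];  // broadcast factor per dim
    mul1[i] = out_shape[i] / in_shape1[i];
  }
  for (int i = 0; i < ndim; ++i)
    printf("dim %d: mul0=%d mul1=%d\n", i, mul0[i], mul1[i]);
  // expected: dim 0: mul0=1 mul1=2; dim 1: mul0=4 mul1=1; dim 2: mul0=1 mul1=1
  return 0;
}

Given those multiples, the two TileOneDimensionFp16 calls in TileDimensionsFp16 expand both inputs into tile_data0_/tile_data1_ of the full output size, which is why Run() mallocs both at out_elements_num_ * sizeof(float16_t).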

