
adjust fp16 1x1 conv

tags/v0.7.0-beta
fuzhiye committed 5 years ago, commit cbab9e74d9
13 changed files with 280 additions and 23 deletions
  1. mindspore/lite/src/populate_parameter.cc (+2, -1)
  2. mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc (+4, -0)
  3. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc (+171, -10)
  4. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h (+25, -1)
  5. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc (+10, -3)
  6. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+14, -2)
  7. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc (+10, -2)
  8. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc (+10, -2)
  9. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c (+6, -1)
  10. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h (+4, -1)
  11. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c (+21, -0)
  12. mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h (+2, -0)
  13. mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h (+1, -0)

mindspore/lite/src/populate_parameter.cc (+2, -1)

@@ -171,7 +171,6 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) {
pooling_param->global_ = pooling_primitive->global();
pooling_param->window_w_ = pooling_primitive->windowW();
pooling_param->window_h_ = pooling_primitive->windowH();
// todo format
auto pooling_lite_primitive = (lite::Pooling *)primitive;
MS_ASSERT(nullptr != pooling_lite_primitive);
pooling_param->pad_u_ = pooling_lite_primitive->PadUp();
@@ -181,6 +180,8 @@ OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) {
pooling_param->stride_w_ = pooling_primitive->strideW();
pooling_param->stride_h_ = pooling_primitive->strideH();

auto is_global = pooling_primitive->global();
pooling_param->global_ = is_global;
auto pool_mode = pooling_primitive->poolingMode();
switch (pool_mode) {
case schema::PoolMode_MAX_POOLING:


mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc (+4, -0)

@@ -76,6 +76,10 @@ int PoolingBaseCPUKernel::Init() {
pooling_param_->output_channel_ = out_tensor->Channel();
pooling_param_->output_h_ = out_tensor->Height();
pooling_param_->output_w_ = out_tensor->Width();
if (pooling_param_->global_) {
pooling_param_->window_h_ = pooling_param_->input_h_;
pooling_param_->window_w_ = pooling_param_->input_w_;
}
return RET_OK;
}
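
The new branch is what lets a global pooling node run through the ordinary pooling kernels: with the window widened to the whole feature map, every output value reduces over all input positions and the output collapses to one pixel per channel. A minimal sketch of the resulting computation for average pooling, assuming NHWC layout and hypothetical in/out buffers:

/* global average pooling == ordinary pooling with window_h/w = input_h/w */
for (int c = 0; c < channel; c++) {
  float sum = 0.0f;
  for (int h = 0; h < input_h; h++) {
    for (int w = 0; w < input_w; w++) {
      sum += in[(h * input_w + w) * channel + c];
    }
  }
  out[c] = sum / (float)(input_h * input_w);  /* one output value per channel */
}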



mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc (+171, -10)

@@ -27,16 +27,125 @@
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
int Convolution1x1FP16CPUKernel::InitMatmulParam() {
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
return RET_OK;
}
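
These parameters encode the textbook reduction of a 1x1 convolution over NHWC data to a matrix multiply: each output pixel is a dot product of its input pixel with one filter per output channel, so row_ = output_h_ * output_w_ pixels, deep_ = input_channel_, col_ = output_channel_, and row_16_/col_8_ round the matrix up to the 16-row and 8-column tiles the fp16 kernels consume. A naive reference of the equivalence, with hypothetical row-major buffers (weights assumed laid out [deep_ x col_] for this sketch only):

/* out[row x col] = in[row x deep] * w[deep x col] + bias, one matrix per batch */
for (int m = 0; m < out_h * out_w; m++) {          /* row_  */
  for (int n = 0; n < out_channel; n++) {          /* col_  */
    float acc = (bias != NULL) ? (float)bias[n] : 0.0f;
    for (int k = 0; k < in_channel; k++) {         /* deep_ */
      acc += (float)in[m * in_channel + k] * (float)w[k * out_channel + n];
    }
    out[m * out_channel + n] = (float16_t)acc;     /* plus relu/relu6 per act_type_ */
  }
}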

int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<float16_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float16_t));
}

thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;

pack_input_ =
reinterpret_cast<float16_t *>(malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;
}
memset(pack_input_, 0, matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t));
return RET_OK;
}
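
The parallelism here is across output channels, not pixels: col_ is cut into 8-wide blocks, each task takes thread_stride_ consecutive columns, and RunImpl clamps the tail with cur_oc. For example, col_ = 100 with 4 threads gives UP_DIV(100, 8) = 13 blocks, thread_stride_ = UP_DIV(13, 4) * 8 = 32, and the tasks cover 32/32/32/4 channels. A sketch of the partition, reusing nnacl's UP_DIV/MSMIN/C8NUM:

/* column partition driving RunImpl (illustrative fragment) */
int col_blocks = UP_DIV(col, C8NUM);                   /* 8-wide blocks           */
int threads = MSMIN(max_threads, col_blocks);          /* thread_count_           */
int stride = UP_DIV(col_blocks, threads) * C8NUM;      /* thread_stride_          */
for (int task_id = 0; task_id < threads; task_id++) {
  int cur_oc = MSMIN(stride, col - task_id * stride);  /* tail task gets the rest */
  if (cur_oc <= 0) continue;
  /* task owns output channels [task_id * stride, task_id * stride + cur_oc) */
}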

int Convolution1x1FP16CPUKernel::InitWeightBias() {
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
if (in_tensors_.size() == 3) {
bias_data_ = malloc(matmul_param_->col_8_ * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
return RET_ERROR;
}
memset(bias_data_, 0, matmul_param_->col_8_ * sizeof(float16_t));
memcpy(bias_data_, in_tensors_[2]->Data(), conv_param_->output_channel_ * sizeof(float16_t));
} else {
bias_data_ = nullptr;
}

weight_ptr_ = reinterpret_cast<float16_t *>(malloc(matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t)));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
return RET_ERROR;
}
memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t));
RowMajor2Col8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->col_,
matmul_param_->deep_);

return RET_OK;
}

int Convolution1x1FP16CPUKernel::InitBuffer() {
/*=============================fp16_input_============================*/
size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
memset(fp16_input_, 0, fp16_input_size);

/*=============================fp16_out_============================*/
size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ *
conv_param_->output_w_ * sizeof(float16_t);
fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
if (fp16_out_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_out_ failed.";
return RET_ERROR;
}
return RET_OK;
}

int Convolution1x1FP16CPUKernel::Init() {
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
ret = InitMatmulParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init matmul param failed.";
return ret;
}
ret = InitConv1x1Param();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init conv1x1 param failed.";
return ret;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init buffer failed.";
return ret;
}
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return ret;
}
return RET_OK;
}

@@ -47,8 +156,14 @@ int Convolution1x1FP16CPUKernel::ReSize() {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
- if (nhwc4_input_ != nullptr) {
- free(nhwc4_input_);
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
}
if (input_ptr_ != nullptr) {
free(input_ptr_);
}
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
}

auto ret = ConvolutionBaseCPUKernel::Init();
@@ -56,13 +171,49 @@ int Convolution1x1FP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
ret = InitMatmulParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init matmul param failed.";
return ret;
}
ret = InitConv1x1Param();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init conv1x1 param failed.";
return ret;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init buffer failed.";
return ret;
}

return RET_OK;
}

void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *src_output) {
output_ptr_ = src_output;
if (pre_trans_input_) {
Conv1x1InputPackFp16(src_input, input_ptr_, conv_param_);
} else {
input_ptr_ = src_input;
}

RowMajor2Col8MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
return;
}

int Convolution1x1FP16CPUKernel::RunImpl(int task_id) {
// Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
// reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_,
// tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}

auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float16_t *>(bias_data_) + thread_stride_ * task_id;

MatMulFp16(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, true);

return RET_OK;
}

@@ -83,12 +234,22 @@ int Convolution1x1FP16CPUKernel::Run() {
return RET_ERROR;
}

- ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get executor tensor failed.";
return ret;
}

for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(
execute_input_ + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
execute_output_ + batch_index * matmul_param_->row_ * matmul_param_->col_);

- int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
- if (error_code != RET_OK) {
- MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
- return RET_ERROR;
int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
return RET_ERROR;
}
}

ConvolutionBaseFP16CPUKernel::IfCastOutput();


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h (+25, -1)

@@ -22,6 +22,8 @@
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
#include "src/runtime/kernel/arm/nnacl/matmul_parameter.h"
#include "src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h"

namespace mindspore::kernel {
class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
@@ -29,7 +31,9 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
- : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {
matmul_param_ = new MatMulParameter();
}
~Convolution1x1FP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
@@ -40,14 +44,34 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
if (input_ptr_ != nullptr) {
free(input_ptr_);
}
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
}
delete matmul_param_;
}

int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitBuffer();
int InitConv1x1Param();
int InitMatmulParam();
int InitWeightBias();
void Pre1x1Trans(float16_t *src_input, float16_t *src_output);

private:
bool pre_trans_input_ = false;
int thread_count_ = 0;
int thread_stride_ = 0;
float16_t *weight_ptr_ = nullptr;
float16_t *input_ptr_ = nullptr;
float16_t *pack_input_ = nullptr;
float16_t *output_ptr_ = nullptr;
MatMulParameter *matmul_param_ = nullptr;
};
} // namespace mindspore::kernel



mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc (+10, -3)

@@ -62,7 +62,11 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(transformed_filter_addr_, 0, transformed_size);
- ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);

// init bias
@@ -249,8 +253,11 @@ int Convolution3x3FP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
- ConvolutionBaseFP16CPUKernel::GetExecuteTensor();

ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get execute tensor failed.";
return ret;
}
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+14, -2)

@@ -18,6 +18,7 @@
#include <vector>
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
@@ -46,7 +47,11 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;

// init weight
- ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
@@ -218,7 +223,12 @@ int ConvolutionFP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
- ConvolutionBaseFP16CPUKernel::GetExecuteTensor();

ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute tensor failed.";
return ret;
}

int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
@@ -256,6 +266,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
bool use_winograd = false;
int out_unit;


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc (+10, -2)

@@ -39,7 +39,11 @@ int ConvolutionSWFP16CPUKernel::ProcessFilter() {
int out_channel = conv_param_->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);

- ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}

for (int oc = 0; oc < out_channel; ++oc) {
int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
@@ -228,7 +232,11 @@ int ConvolutionSWFP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
- ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute tensor failed.";
return ret;
}

int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc (+10, -2)

@@ -115,7 +115,11 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
return RET_ERROR;
}

- ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);

// init bias
@@ -377,7 +381,11 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
return prepare_ret;
}

- ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute tensor failed.";
return ret;
}

int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c (+6, -1)

@@ -16,6 +16,11 @@

#include "nnacl/fp16/matmul_fp16.h"

void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, bool write_nhwc) {
MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
}

void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row16 = row / C16NUM * C16NUM;
size_t col8 = col / C8NUM * C8NUM;
@@ -134,7 +139,7 @@ void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row,
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C16NUM; tr++) {
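
RowMajor2Col8MajorFp16 packs both the 1x1 activation matrix (rows = pixels) and, in InitWeightBias, the weight matrix (rows = output channels). The pointer math in RunImpl (weight_ptr_ + task_id * thread_stride_ * deep_) only lines up if every block of 8 rows is stored contiguously, column-major inside the block, so we read the layout as the following scalar equivalent; this is an assumption about the packed format, not part of the diff, and the NEON path above just streams it 16 rows at a time:

/* presumed col8-blocked layout; padding rows/cols are expected to be pre-zeroed */
void RowMajor2Col8MajorRefFp16(const float16_t *src, float16_t *dst, size_t row, size_t col) {
  for (size_t r = 0; r < row; r++) {
    for (size_t c = 0; c < col; c++) {
      dst[(r / 8) * col * 8 + c * 8 + (r % 8)] = src[r * col + c];
    }
  }
}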


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h (+4, -1)

@@ -29,10 +29,13 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, bool write_nhwc);

void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
#ifdef __aarch64__
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int row,
- int col);
int col, int stride, bool write_nhwc);
#endif
#ifdef __cplusplus
}
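
The two added parameters are what let each task write its slice straight into the final NHWC tensor: stride is the row pitch of the destination in elements (the full output channel count), and write_nhwc selects plain row-major stores instead of a packed output. That matches RunImpl passing output_ptr_ + task_id * thread_stride_ together with stride = matmul_param_->col_. A call shaped like the one in RunImpl, with abbreviated hypothetical names:

/* write cur_oc channels of every pixel into an NHWC output with col channels per pixel */
MatMulFp16(packed_a,                      /* packed [row x deep] activations  */
           packed_b + offset * deep,      /* weights for this channel slice   */
           out + offset,                  /* first channel owned by this task */
           (bias != NULL) ? bias + offset : NULL,
           ActType_No, deep, row, cur_oc,
           col,                           /* stride: full channel count       */
           true);                         /* write_nhwc                       */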


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c (+21, -0)

@@ -18,6 +18,27 @@
#include <string.h>
#include <stdlib.h>

void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) {
/* support nhwc */
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_,
conv_param->input_channel_ * sizeof(float16_t));
}
}
return;
}
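
Pre1x1Trans only invokes this pack when padding or stride is non-trivial; it copies, for each output pixel, the single input pixel a 1x1 kernel would read, so the matmul can treat the result as a dense [row_ x deep_] matrix, and positions skipped by the bounds checks stay zero because the caller memset input_ptr_. With stride 2 and no padding, for instance, output pixel (h, w) takes the channel vector at input pixel (2h, 2w); a toy single-channel check of the index math:

/* input 4x4, one channel, stride 2, pad 0 -> output 2x2 reads (0,0) (0,2) (2,0) (2,2) */
for (int dst_h = 0; dst_h < 2; dst_h++) {
  int src_h = dst_h * 2;                 /* stride_h_ = 2, pad_h_ = 0 */
  for (int dst_w = 0; dst_w < 2; dst_w++) {
    int src_w = dst_w * 2;
    dst[dst_h * 2 + dst_w] = src[src_h * 4 + src_w];
  }
}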

void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc


mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h (+2, -0)

@@ -26,6 +26,8 @@
#ifdef __cplusplus
extern "C" {
#endif
void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param);

void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index);



mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h (+1, -0)

@@ -26,6 +26,7 @@ typedef struct MatMulParameter {
int row_;
int col_;
int row_8_;
int row_16_;
int col_8_;
int deep_;
bool has_bias_;
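
row_16_ joins the existing row_8_/col_8_ padding fields because the fp16 path tiles matmul rows by 16 instead of 8; it is filled with UP_ROUND in InitMatmulParam above. By our reading, these nnacl helpers are the usual ceiling macros:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))      /* ceiling division            */
#define UP_ROUND(x, y) (UP_DIV((x), (y)) * (y))   /* round up to a multiple of y */
/* e.g. row_ = 50 -> row_16_ = UP_ROUND(50, 16) = 64 */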

