
!6887 [MS][LITE][CPU] optimize fp16 common conv

Merge pull request !6887 from fuzhiye/tmp
tags/v1.1.0
mindspore-ci-bot, 5 years ago
commit fab9ca38f5
9 changed files with 47 additions and 1226 deletions:
  1. mindspore/lite/nnacl/fp16/conv_fp16.c (+13, -185)
  2. mindspore/lite/nnacl/fp16/conv_fp16.h (+1, -12)
  3. mindspore/lite/nnacl/fp16/pack_fp16.c (+20, -71)
  4. mindspore/lite/nnacl/fp16/pack_fp16.h (+0, -2)
  5. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc (+0, -261)
  6. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h (+0, -85)
  7. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+9, -15)
  8. mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h (+4, -4)
  9. mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc (+0, -591)
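
At a glance: the dedicated 3x3 kernel and the NHWC4-blocked IndirectGemmFp16_16x8 path are dropped, and common fp16 convolution becomes a plain-NHWC im2col feeding a col-major-repacked MatMulFp16 call, with the activation fused into the GEMM. As a reading aid before the per-file diffs, here is a minimal scalar model of that im2col + GEMM structure; float stands in for float16_t, tiling/threading/repacking are omitted, and naive_im2col/naive_gemm are illustrative names, not nnacl APIs:

#include <stdio.h>
#include <string.h>

/* Illustrative only: scalar im2col + GEMM convolution over NHWC data,
 * mirroring the shape of computation ConvFp16 uses after this change. */
static void naive_im2col(const float *in, float *cols, int ih, int iw, int ic,
                         int kh, int kw, int oh, int ow, int pad) {
  for (int oy = 0; oy < oh; oy++) {
    for (int ox = 0; ox < ow; ox++) {
      float *row = cols + (oy * ow + ox) * kh * kw * ic; /* one GEMM row per output pixel */
      for (int ky = 0; ky < kh; ky++) {
        for (int kx = 0; kx < kw; kx++) {
          int y = oy + ky - pad, x = ox + kx - pad;
          float *dst = row + (ky * kw + kx) * ic;
          if (y < 0 || y >= ih || x < 0 || x >= iw) {
            memset(dst, 0, ic * sizeof(float));                      /* zero padding */
          } else {
            memcpy(dst, in + (y * iw + x) * ic, ic * sizeof(float)); /* NHWC: channels contiguous */
          }
        }
      }
    }
  }
}

/* out[r][c] = sum_d cols[r][d] * w[d][c] + bias[c], with deep = kh*kw*ic */
static void naive_gemm(const float *cols, const float *w, const float *bias,
                       float *out, int rows, int deep, int oc) {
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < oc; c++) {
      float acc = bias ? bias[c] : 0.0f;
      for (int d = 0; d < deep; d++) acc += cols[r * deep + d] * w[d * oc + c];
      out[r * oc + c] = acc;
    }
  }
}

int main(void) {
  enum { IH = 4, IW = 4, IC = 2, KH = 3, KW = 3, OC = 1, PAD = 1, OH = 4, OW = 4 };
  float in[IH * IW * IC], w[KH * KW * IC * OC], cols[OH * OW * KH * KW * IC], out[OH * OW * OC];
  for (int i = 0; i < IH * IW * IC; i++) in[i] = 1.0f;
  for (int i = 0; i < KH * KW * IC * OC; i++) w[i] = 1.0f;
  naive_im2col(in, cols, IH, IW, IC, KH, KW, OH, OW, PAD);
  naive_gemm(cols, w, NULL, out, OH * OW, KH * KW * IC, OC);
  printf("corner=%g center=%g\n", out[0], out[(1 * OW + 1) * OC]); /* 8 and 18 with all-ones data */
  return 0;
}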

mindspore/lite/nnacl/fp16/conv_fp16.c (+13, -185)

@@ -17,6 +17,7 @@
#include <string.h>
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"

#ifdef __cplusplus
extern "C" {
@@ -122,7 +123,8 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we

// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
-              float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) {
+              float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) {
+  const int tile_n = 16;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
@@ -132,203 +134,29 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
-  bool relu = conv_param->act_type_ == ActType_Relu;
-  bool relu6 = conv_param->act_type_ == ActType_Relu6;
   int thread_count = conv_param->thread_num_;
-  const int tile_n = 16;
   int output_count = out_h * out_w;
   int output_tile_count = UP_DIV(output_count, tile_n);

-  int channel_block = UP_DIV(in_channel, C4NUM);
   int kernel_plane = kernel_h * kernel_w;
-  int unit_size = kernel_plane * channel_block * C4NUM;
-
-  // we accumulate 4 channels per time for input blocks
-  int ic4 = UP_DIV(in_channel, C4NUM);
-  int conv_depth = kernel_h * kernel_w;
-  // bytes from one output's i-th channel to the next output's i-th channel
-  // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward
+  int deep = kernel_plane * in_channel;

for (int b = 0; b < in_batch; b++) {
-    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
+    int in_batch_offset = b * in_channel * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
-      float16_t *gemm_input = (float16_t *)(packed_input + task_id * unit_size * tile_n);
+      float16_t *gemm_input = packed_input + task_id * deep * tile_n;
+      float16_t *col_major_gemm_input = col_major_input + task_id * deep * tile_n;
+      size_t packed_input_size = deep * tile_n * sizeof(float16_t);
+      memset(gemm_input, 0, packed_input_size);
+      memset(col_major_gemm_input, 0, packed_input_size);
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
-      if (real_cal_num == tile_n) {
-        float16_t *gemm_output = output_data + out_offset;
-        IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
-                              out_channel * sizeof(float16_t), 0, 0, relu, relu6);
-      } else {
-        // res part
-        float16_t *tmp_out_ptr = tmp_out_block + task_id * tile_n * out_channel;
-        IndirectGemmFp16_16x8(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
-                              out_channel * sizeof(float16_t), 0, 0, relu, relu6);
-        memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float16_t));
-      }
+      RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
+      MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
+                 real_cal_num, out_channel, out_channel, OutType_Nhwc);
}
}
}

// fp16 conv3x3
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
int task_id, ConvParameter *conv_param) {
int thread_count = conv_param->thread_num_;
const int tile_num = 16;
const int output_unit = 4;
const int k_plane = 36;
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int ic4 = ic8 * 2;
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);

int out_w_block = UP_DIV(conv_param->output_w_, C4NUM);
int out_h_block = UP_DIV(conv_param->output_h_, C4NUM);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM;
int block_unit_buffer_offset = k_plane * C8NUM;
int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM;

int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_num;
int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;

Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);

IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset,
tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0);

Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}
}
}

void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) {
int out_w_block = UP_DIV(width, C4NUM);
int out_h_block = UP_DIV(height, C4NUM);
int oc8 = UP_DIV(channel, C8NUM);

for (int b = 0; b < batch; b++) {
int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int ro_batch_size = b * channel * height * width;
const float16_t *batch_tmp_out = src + tmp_out_batch_offset;
float16_t *batch_out = dst + ro_batch_size;
for (int h = 0; h < height; h++) {
int src_h_offset = h * out_w_block * C4NUM * C8NUM;
const int dst_h_offset = h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C8NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < oc8 - 1; ++c) {
int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset;
int dst_offset = dst_w_offset + c * C8NUM;
vst1q_f16(batch_out + dst_offset, vld1q_f16(batch_tmp_out + src_offset));
}

int c_res = channel - (oc8 - 1) * C8NUM;
int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM;
for (int c = 0; c < c_res; c++) {
int src_offset = src_c_res_offset + c;
int dst_offset = dst_c_res_offset + c;
(batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
}
}
}
}
}

void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) {
int out_w_block = UP_DIV(width, C4NUM);
int out_h_block = UP_DIV(height, C4NUM);
int oc8 = UP_DIV(channel, C8NUM);

for (int b = 0; b < batch; b++) {
int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int ro_batch_size = b * channel * height * width;
const float16_t *batch_tmp_out = src + tmp_out_batch_offset;
float16_t *batch_out = dst + ro_batch_size;
for (int h = 0; h < height; h++) {
int src_h_offset = h * out_w_block * C4NUM * C8NUM;
const int dst_h_offset = h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C8NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < oc8 - 1; ++c) {
int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset;
int dst_offset = dst_w_offset + c * C8NUM;
float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset);
float16x8_t zero = vdupq_n_f16(0);
input_ptr = vmaxq_f16(zero, input_ptr);
vst1q_f16(batch_out + dst_offset, input_ptr);
}

int c_res = channel - (oc8 - 1) * C8NUM;
int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM;
for (int c = 0; c < c_res; c++) {
int src_offset = src_c_res_offset + c;
int dst_offset = dst_c_res_offset + c;
float16_t input_data = (batch_tmp_out + src_offset)[0];
input_data = input_data < 0 ? 0 : input_data;
(batch_out + dst_offset)[0] = input_data;
}
}
}
}
}

void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel) {
int out_w_block = UP_DIV(width, C4NUM);
int out_h_block = UP_DIV(height, C4NUM);
int oc8 = UP_DIV(channel, C8NUM);

for (int b = 0; b < batch; b++) {
int tmp_out_batch_offset = b * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int ro_batch_size = b * channel * height * width;
const float16_t *batch_tmp_out = src + tmp_out_batch_offset;
float16_t *batch_out = dst + ro_batch_size;
for (int h = 0; h < height; h++) {
int src_h_offset = h * out_w_block * C4NUM * C8NUM;
const int dst_h_offset = h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C8NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < oc8 - 1; ++c) {
int src_offset = c * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + src_w_offset;
int dst_offset = dst_w_offset + c * C8NUM;
float16x8_t input_ptr = vld1q_f16(batch_tmp_out + src_offset);
float16x8_t zero = vdupq_n_f16(0);
float16x8_t six = vdupq_n_f16(6);
input_ptr = vmaxq_f16(zero, input_ptr);
input_ptr = vminq_f16(six, input_ptr);
vst1q_f16(batch_out + dst_offset, input_ptr);
}

int c_res = channel - (oc8 - 1) * C8NUM;
int src_c_res_offset = src_w_offset + (oc8 - 1) * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int dst_c_res_offset = dst_w_offset + (oc8 - 1) * C8NUM;
for (int c = 0; c < c_res; c++) {
int src_offset = src_c_res_offset + c;
int dst_offset = dst_c_res_offset + c;
float16_t input_data = (batch_tmp_out + src_offset)[0];
input_data = input_data < 0 ? 0 : input_data;
input_data = input_data > 6 ? 6 : input_data;
(batch_out + dst_offset)[0] = input_data;
}
}
}
}
}
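
A note on the tiling arithmetic used by the rewritten ConvFp16 at the top of this file: outputs are processed 16 at a time (tile_n), the tile count comes from UP_DIV (ceil division; its usual definition is assumed below), and tiles are dealt to threads round-robin via thread_id += thread_count, with real_cal_num shrinking for the last, partial tile. A toy calculation with a hypothetical 26x26 output and 4 worker threads:

#include <stdio.h>

/* assumed definition of nnacl's ceil-division helper */
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  const int tile_n = 16;                                 /* outputs per GEMM tile */
  int out_h = 26, out_w = 26, thread_count = 4;          /* hypothetical shape and thread count */
  int output_count = out_h * out_w;                      /* 676 output positions */
  int output_tile_count = UP_DIV(output_count, tile_n);  /* 43 tiles; the last covers only 4 */
  for (int task_id = 0; task_id < thread_count; task_id++) {
    int tiles = 0;
    /* same round-robin walk as ConvFp16: task_id, task_id + thread_count, ... */
    for (int t = task_id; t < output_tile_count; t += thread_count) {
      int start_index = t * tile_n;
      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
      if (real_cal_num < tile_n) printf("task %d gets the partial tile %d (%d outputs)\n", task_id, t, real_cal_num);
      tiles++;
    }
    printf("task %d processes %d tiles\n", task_id, tiles);
  }
  return 0;
}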


mindspore/lite/nnacl/fp16/conv_fp16.h (+1, -12)

@@ -43,18 +43,7 @@ extern "C" {

// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
-              float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param);
-
-// fp16 conv3x3
-void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
-                 float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
-                 int task_id, ConvParameter *conv_param);
-
-void UnPack3x3OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel);
-
-void UnPack3x3ReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel);
-
-void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel);
+              float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param);

// fp16 convolution winograd
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
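
The new ConvFp16 signature shifts scratch ownership to the caller: packed_input holds the im2col rows and col_major_input the 16-row col-major repack, each deep * tile_n elements per thread, where deep = kernel_h * kernel_w * in_channel. A sizing sketch that mirrors InitTmpBuffer in convolution_fp16.cc below (shape numbers made up, float standing in for float16_t):

#include <stdlib.h>

int main(void) {
  const int tile_n = 16;                                  /* cal_num in the kernel code */
  int kernel_h = 3, kernel_w = 3, in_channel = 32, thread_num = 2;
  size_t deep = (size_t)kernel_h * kernel_w * in_channel; /* GEMM depth: 288 */
  size_t elems = deep * tile_n * thread_num;              /* per-thread slice times threads */

  /* ConvFp16 offsets into each buffer at task_id * deep * tile_n */
  float *packed_input = malloc(elems * sizeof(float));    /* im2col rows */
  float *col_major_input = malloc(elems * sizeof(float)); /* 16-row col-major repack */
  if (packed_input == NULL || col_major_input == NULL) return 1;

  free(packed_input);
  free(col_major_input);
  return 0;
}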


mindspore/lite/nnacl/fp16/pack_fp16.c (+20, -71)

@@ -23,6 +23,7 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
+  int kernel_plane = kernel_h * kernel_w;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
@@ -33,9 +34,6 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
-  int ic4 = UP_DIV(in_channel, 4);
-  int ic4_minus = in_channel / 4;
-  memset(packed_input, 0, kernel_w * kernel_h * ic4 * C4NUM * 16 * sizeof(float16_t));

for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
@@ -46,74 +44,25 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
-    for (int j = kh_s; j < kh_e; j++) {
-      int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
-      for (int n = kw_s; n < kw_e; n++) {
-        int input_x_stride = input_y_stride + n * dilation_w * in_channel;
-        int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * ic4 + i * C4NUM;
-        for (int m = 0; m < ic4_minus; m++) {
-          int channel_block_stride = input_x_stride + m * C4NUM;
-          int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
-#ifdef ENABLE_ARM64
-          vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride));
-#else
-          for (int l = 0; l < C4NUM; ++l) {
-            (packed_input + channel_block_offset)[l] = (input_data + channel_block_stride)[l];
-          }
-#endif
-        }  // channel_block loop
-        int ic_res = in_channel - ic4_minus * C4NUM;
-        for (int l = 0; l < ic_res; ++l) {
-          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
-          int channel_block_offset = input_plane_offset + ic4_minus * 16 * C4NUM + l;
-          packed_input[channel_block_offset] = input_data[channel_block_stride];
-        }
-      }  // kernel_w loop
-    }    // kernel_h loop
+    if (dilation_h == 1 && dilation_w == 1) {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * in_w * in_channel + input_stride;
+        int input_x_stride = input_y_stride + kw_s * in_channel;
+        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
+        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
+               (kw_e - kw_s) * in_channel * sizeof(float16_t));
+      }  // kernel_h loop
+    } else {
+      for (int j = kh_s; j < kh_e; j++) {
+        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
+        for (int n = kw_s; n < kw_e; n++) {
+          int input_x_stride = input_y_stride + n * dilation_w * in_channel;
+          int input_plane_offset = (j * kernel_w + n) * in_channel + i * in_channel * kernel_plane;
+          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float16_t));
+        }  // kernel_w loop
+      }    // kernel_h loop
+    }
  }  // tile num loop
}

-void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) {
-  // original weight format : ohwi
-  const int tile_num = 8;
-  const int inchannel_block = 4;
-  int kernel_h = conv_param->kernel_h_;
-  int kernel_w = conv_param->kernel_w_;
-  int in_channel = conv_param->input_channel_;
-  int out_channel = conv_param->output_channel_;
-  int kernel_block = UP_DIV(out_channel, tile_num);
-  int channel_block = UP_DIV(in_channel, inchannel_block);
-  int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane;
-
-  int unit_size = tile_num * inchannel_block;
-  int block_size = pack_weight_size / kernel_block;
-
-  for (int m = 0; m < kernel_plane; m++) {
-    int kernel_plane_stride = m * in_channel;
-    int packed_kernel_plane_stride = m * unit_size * channel_block;
-    for (int i = 0; i < channel_block; i++) {
-      int channel_block_stride = kernel_plane_stride + i * inchannel_block;
-      int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
-      int ic_remainder = in_channel - i * inchannel_block;
-      int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block;
-      for (int h = 0; h < real_ic_num; h++) {
-        int block_stride = channel_block_stride + h;
-        int packed_block_stride = packed_channel_block_size + h * tile_num;
-        for (int j = 0; j < kernel_block; j++) {
-          int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel;
-          int packed_kernel_block_size = packed_block_stride + j * block_size;
-          int oc_remainder = out_channel - j * tile_num;
-          int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num;
-          for (int k = 0; k < real_oc_num; k++) {
-            float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
-            float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k;
-            *packed_data_ptr = *origin_data_ptr;
-          }
-        }  // kernel block loop
-      }    // inchannel block loop
-    }      // channel block loop
-  }        // kernel plane loop
-}

void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
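
The point of the dilation == 1 branch in the new Im2ColPackUnitFp16 above: with unit dilation, the kw_s..kw_e taps of a kernel row are adjacent NHWC pixels, so a single memcpy of (kw_e - kw_s) * in_channel elements replaces one copy per tap; with dilation, the taps are strided and need one in_channel-sized copy each. A toy demonstration of the two copy patterns (float, made-up sizes):

#include <stdio.h>
#include <string.h>

int main(void) {
  enum { IW = 8, IC = 4, KW = 3 };
  float row[IW * IC];                  /* one input row, NHWC: channels innermost */
  for (int i = 0; i < IW * IC; i++) row[i] = (float)i;

  float a[KW * IC], b[KW * IC];
  int input_x = 2;                     /* leftmost tap of the kernel row */

  /* dilation == 1: all KW taps are adjacent pixels, one contiguous block */
  memcpy(a, row + input_x * IC, KW * IC * sizeof(float));

  /* general case: one copy per tap (here dilation_w == 2) */
  int dilation_w = 2;
  for (int n = 0; n < KW; n++)
    memcpy(b + n * IC, row + (input_x + n * dilation_w) * IC, IC * sizeof(float));

  printf("a[0]=%g b[0]=%g (same first tap, different strides)\n", a[0], b[0]);
  return 0;
}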


mindspore/lite/nnacl/fp16/pack_fp16.h (+0, -2)

@@ -29,8 +29,6 @@ extern "C" {
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index);

-void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight);
-
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);

void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc (+0, -261)

@@ -1,261 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param) {
auto input_channel = conv_param->input_channel_;
auto output_channel = conv_param->output_channel_;
auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
int iC8 = UP_DIV(input_channel, C8NUM);
int oC8 = UP_DIV(output_channel, C8NUM);

size_t tmp_size = oC8 * C8NUM * iC8 * C8NUM * kernel_plane * sizeof(float16_t);
auto tmp_addr = reinterpret_cast<float16_t *>(malloc(tmp_size));
if (tmp_addr == nullptr) {
MS_LOG(ERROR) << "malloc tmp_addr failed.";
return;
}
memset(tmp_addr, 0, tmp_size);

PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param);
Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC8 * 2, output_channel, kernel_plane);

free(tmp_addr);
}

int Convolution3x3FP16CPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int iC8 = UP_DIV(input_channel, C8NUM);
int oC8 = UP_DIV(output_channel, C8NUM);

size_t transformed_size = iC8 * C8NUM * oC8 * C8NUM * 36 * sizeof(float16_t);
transformed_filter_addr_ = reinterpret_cast<float16_t *>(malloc(transformed_size));
if (transformed_filter_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed.";
return RET_ERROR;
}
memset(transformed_filter_addr_, 0, transformed_size);
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);

size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
bias_data_ = malloc(new_bias_size);
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, new_bias_size);
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (in_tensors_.size() == kInputSize2) {
auto ori_bias_addr = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
for (int i = 0; i < output_channel; ++i) {
fp16_bias_data[i] = (float16_t)ori_bias_addr[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
}
return RET_OK;
}

int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
const int tile_num = 16;
const int k_plane = 36;
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
int iC8 = UP_DIV(conv_param_->input_channel_, C8NUM);
MS_ASSERT(ctx_->allocator != nullptr);

size_t nhwc8_input_size =
iC8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = ctx_->allocator->Malloc(nhwc8_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}

size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC8 * C8NUM * sizeof(float16_t);
tile_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile_buffer_ failed.";
return RET_ERROR;
}

size_t block_unit_buffer_size = thread_count_ * k_plane * C8NUM * sizeof(float16_t);
block_unit_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block_unit_buffer_ failed.";
return RET_ERROR;
}

size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t);
tmp_dst_buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}

int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM;
size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t);
tmp_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tmp_out_size));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
return RET_ERROR;
}

return RET_OK;
}

void Convolution3x3FP16CPUKernel::ConfigInputOutput() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_format = input_tensor->GetFormat();
schema::Format execute_format = schema::Format::Format_NHWC4;
convert_func_ = LayoutTransformFp16(input_format, execute_format);
if (convert_func_ == nullptr) {
MS_LOG(ERROR) << "layout convert func is nullptr.";
return;
}
}

int Convolution3x3FP16CPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int Convolution3x3FP16CPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize is invalid.";
return ret;
}

ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
return RET_OK;
}

int Convolution3x3FP16CPUKernel::RunImpl(int task_id) {
Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
reinterpret_cast<float16_t *>(bias_data_), execute_output_, tile_buffer_, block_unit_buffer_,
tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
return RET_OK;
}

static int Convolution3x3Fp16Impl(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution3x3FP16CPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution3x3 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int Convolution3x3FP16CPUKernel::PostProcess() {
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
UnPack3x3OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu:
UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu6:
UnPack3x3Relu6OutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}

int Convolution3x3FP16CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get execute tensor failed.";
return ret;
}
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
PackNHWCToNHWC8Fp16(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution3x3Fp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}

ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
return ret;
}
ConvolutionBaseFP16CPUKernel::IfCastOutput();
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
FreeTmpBuffer();
return RET_OK;
}
} // namespace mindspore::kernel
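
Context for this deletion: the 3x3 kernel's UnPack3x3OutputFp16 / UnPack3x3ReluOutputFp16 / UnPack3x3Relu6OutputFp16 post-passes existed to re-layout the Winograd output and clamp activations afterwards, whereas the common path now hands conv_param->act_type_ straight to MatMulFp16, so the clamp is applied as results are stored. A scalar sketch of that kind of fused epilogue (illustrative, not the nnacl implementation):

#include <stdio.h>

typedef enum { ACT_NO, ACT_RELU, ACT_RELU6 } ActType; /* mirrors ActType_No/Relu/Relu6 */

/* Clamp applied at the point each GEMM result is written, instead of a
 * separate unpack-and-clamp pass over a temporary buffer. */
static float apply_act(float v, ActType act) {
  if (act != ACT_NO && v < 0.0f) v = 0.0f;
  if (act == ACT_RELU6 && v > 6.0f) v = 6.0f;
  return v;
}

int main(void) {
  float acc[4] = {-1.5f, 0.5f, 3.0f, 9.0f}; /* pretend GEMM accumulators */
  for (int i = 0; i < 4; i++) printf("%g ", apply_act(acc[i], ACT_RELU6));
  printf("\n");                             /* prints: 0 0.5 3 6 */
  return 0;
}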

mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h (+0, -85)

@@ -1,85 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "nnacl/optimized_kernel.h"

namespace mindspore::kernel {
class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~Convolution3x3FP16CPUKernel() override {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
}
if (transformed_filter_addr_ != nullptr) {
free(transformed_filter_addr_);
transformed_filter_addr_ = nullptr;
}
}

int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
int PostProcess();

private:
void FreeTmpBuffer() {
if (nhwc4_input_ != nullptr) {
ctx_->allocator->Free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (tile_buffer_ != nullptr) {
ctx_->allocator->Free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (block_unit_buffer_ != nullptr) {
ctx_->allocator->Free(block_unit_buffer_);
block_unit_buffer_ = nullptr;
}
if (tmp_dst_buffer_ != nullptr) {
ctx_->allocator->Free(tmp_dst_buffer_);
tmp_dst_buffer_ = nullptr;
}
if (tmp_out_ != nullptr) {
ctx_->allocator->Free(tmp_out_);
tmp_out_ = nullptr;
}
}
float16_t *transformed_filter_addr_ = nullptr;
float16_t *tile_buffer_ = nullptr;
float16_t *block_unit_buffer_ = nullptr;
float16_t *tmp_dst_buffer_ = nullptr;
float16_t *tmp_out_ = nullptr;
};
void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_

mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc (+9, -15)

@@ -17,7 +17,6 @@
#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
@@ -45,9 +44,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int oc8 = UP_DIV(out_channel, C8NUM);
-  int ic4 = UP_DIV(in_channel, C4NUM);
   int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
+  int pack_weight_size = oc8 * C8NUM * in_channel * kernel_plane;

// init weight
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
@@ -61,7 +59,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
-  PackWeightFp16(execute_weight_, conv_param_, packed_weight_);
+  RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false);

// init bias
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
@@ -83,24 +81,20 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
}

int ConvolutionFP16CPUKernel::InitTmpBuffer() {
+  const int cal_num = 16;
   int in_channel = conv_param_->input_channel_;
-  int out_channel = conv_param_->output_channel_;
-  int channel_block = UP_DIV(in_channel, C4NUM);
-  int cal_num = 16;
   int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
-  int unit_size = kernel_plane * channel_block * C4NUM;
-  int packed_input_size = thread_count_ * cal_num * unit_size;
+  int unit_size = kernel_plane * in_channel * cal_num * thread_count_;

-  packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(packed_input_size * sizeof(float16_t)));
+  packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_input_ failed.";
return RET_ERROR;
}

-  tmp_output_block_ =
-    reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(thread_count_ * cal_num * out_channel * sizeof(float16_t)));
-  if (tmp_output_block_ == nullptr) {
-    MS_LOG(ERROR) << "malloc tmp_output_block_ failed.";
+  col_major_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t)));
+  if (col_major_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc col_major_input_ failed.";
return RET_ERROR;
}
return RET_OK;
@@ -134,7 +128,7 @@ int ConvolutionFP16CPUKernel::ReSize() {
}

int ConvolutionFP16CPUKernel::RunImpl(int task_id) {
-  ConvFp16(execute_input_, packed_input_, packed_weight_, reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_,
+  ConvFp16(execute_input_, packed_input_, packed_weight_, reinterpret_cast<float16_t *>(bias_data_), col_major_input_,
execute_output_, task_id, conv_param_);
return RET_OK;
}
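
On the weight side of this file's change: PackWeightFp16's OHWI-to-blocked layout is replaced by viewing the OHWI tensor as a row-major [out_channel, in_channel * kernel_plane] matrix and repacking it with RowMajor2Col8MajorFp16 into 8-row panels for the GEMM. Below is a simplified float sketch of an 8-row-panel repack of that general kind; this is my reading of the layout, not the exact nnacl routine, which is fp16 and has architecture-specific fast paths:

#include <stdio.h>

enum { C8 = 8 };

/* Row-major src[row][col] -> panels of 8 rows stored column-major:
 * dst index = (panel * cols + c) * C8 + r_in_panel; tail rows are zero-filled. */
static void row_major_to_col8_major(const float *src, float *dst, int rows, int cols) {
  int panels = (rows + C8 - 1) / C8; /* UP_DIV(rows, C8) */
  for (int p = 0; p < panels; p++) {
    for (int c = 0; c < cols; c++) {
      for (int r = 0; r < C8; r++) {
        int row = p * C8 + r;
        dst[(p * cols + c) * C8 + r] = row < rows ? src[row * cols + c] : 0.0f;
      }
    }
  }
}

int main(void) {
  /* hypothetical 3x3 conv, ic = 2, oc = 10 -> weight matrix is 10 x 18 */
  enum { OC = 10, COLS = 3 * 3 * 2 };
  float w[OC * COLS], packed[((OC + C8 - 1) / C8) * C8 * COLS];
  for (int i = 0; i < OC * COLS; i++) w[i] = (float)i;
  row_major_to_col8_major(w, packed, OC, COLS);
  printf("packed[0]=%g packed[1]=%g\n", packed[0], packed[1]); /* column 0 of rows 0 and 1: 0 and 18 */
  return 0;
}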


mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h (+4, -4)

@@ -53,14 +53,14 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
-    if (tmp_output_block_ != nullptr) {
-      ctx_->allocator->Free(tmp_output_block_);
-      tmp_output_block_ = nullptr;
+    if (col_major_input_ != nullptr) {
+      ctx_->allocator->Free(col_major_input_);
+      col_major_input_ = nullptr;
}
}
float16_t *packed_input_ = nullptr;
float16_t *packed_weight_ = nullptr;
-  float16_t *tmp_output_block_ = nullptr;
+  float16_t *col_major_input_ = nullptr;
};
} // namespace mindspore::kernel



mindspore/lite/test/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc (+0, -591)

@@ -1,591 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/utils.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "nnacl/fp16/conv_fp16.h"

namespace mindspore {
class TestConvolutionFp16 : public mindspore::CommonTest {
public:
TestConvolutionFp16() {}
};

void InitConvParamGroup1Fp16(ConvParameter *conv_param) {
conv_param->input_batch_ = 1;
conv_param->input_h_ = 28;
conv_param->input_w_ = 28;
conv_param->input_channel_ = 3;

conv_param->output_batch_ = 1;
conv_param->output_h_ = 28;
conv_param->output_w_ = 28;
conv_param->output_channel_ = 32;

conv_param->kernel_h_ = 3;
conv_param->kernel_w_ = 3;

conv_param->stride_h_ = 1;
conv_param->stride_w_ = 1;

conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;

conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
conv_param->thread_num_ = 1;
}

void InitConvParamGroup2Fp16(ConvParameter *conv_param) {
conv_param->input_batch_ = 1;
conv_param->input_h_ = 128;
conv_param->input_w_ = 128;
conv_param->input_channel_ = 32;

conv_param->output_batch_ = 1;
conv_param->output_h_ = 128;
conv_param->output_w_ = 128;
conv_param->output_channel_ = 32;

conv_param->kernel_h_ = 3;
conv_param->kernel_w_ = 3;

conv_param->stride_h_ = 1;
conv_param->stride_w_ = 1;

conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;

conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
conv_param->thread_num_ = 1;
}

TEST_F(TestConvolutionFp16, ConvTest1) {
// prepare stage
auto conv_param = new ConvParameter();
InitConvParamGroup1Fp16(conv_param);

int tile_num = 16;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int kernel_plane = k_h * k_w;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int i_h = conv_param->input_h_;
int i_w = conv_param->input_w_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);

size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
std::cout << "==============fp32 weight data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << weight_data[i] << ", ";
}
std::cout << std::endl;

std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl;

int weight_ele_size = weight_size / sizeof(float);
auto fp16_weight_data = new float16_t[weight_ele_size];
for (int i = 0; i < weight_ele_size; i++) {
fp16_weight_data[i] = static_cast<float16_t>(weight_data[i]);
}

std::cout << "==============fp16 weight data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << fp16_weight_data[i] << ", ";
}
std::cout << std::endl;

auto packed_weight = reinterpret_cast<float16_t *>(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t)));
PackWeightFp16(fp16_weight_data, conv_param, packed_weight);

std::cout << "==============fp16 packed weight data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << packed_weight[i] << ", ";
}
std::cout << std::endl;

size_t input_size;
std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
std::cout << "==============fp32 input data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << input_data[i] << ", ";
}
std::cout << std::endl;

int input_ele_size = input_size / sizeof(float);
auto fp16_input_data = new float16_t[input_ele_size];
for (int i = 0; i < input_ele_size; i++) {
fp16_input_data[i] = static_cast<float16_t>(input_data[i]);
}

auto nhwc4_input_data = reinterpret_cast<float16_t *>(malloc(i_h * i_w * ic4 * C4NUM * sizeof(float16_t)));
PackNHWCToNHWC4Fp32(fp16_input_data, nhwc4_input_data, 1, i_h * i_w, in_channel);

std::cout << "==============fp16 input data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << fp16_input_data[i] << ", ";
}
std::cout << std::endl;

int output_count = conv_param->output_h_ * conv_param->output_w_;
int output_tile_count = UP_DIV(output_count, tile_num);
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_num * unit_size;
auto packed_input = reinterpret_cast<float16_t *>(malloc(in_batch * packed_input_size * sizeof(float16_t)));
memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t));

auto bias_data = reinterpret_cast<float16_t *>(malloc(conv_param->output_channel_ * sizeof(float16_t)));
memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t));

size_t output_data_size =
conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
auto output_data = new float16_t[output_data_size];
auto tmp_output_block = reinterpret_cast<float16_t *>(malloc(tile_num * out_channel * sizeof(float16_t)));

// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
// warmup
for (int i = 0; i < 3; i++) {
ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param);
}

int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
ConvFp16(nhwc4_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param);
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
time_avg = cost / loop_count;
printf("single thread running time : %f ms\n", time_avg / 1000.0f);

std::cout << "==============fp16 output data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;

auto fp32_output_data = new float[output_data_size];
for (int i = 0; i < output_data_size; i++) {
fp32_output_data[i] = static_cast<float>(output_data[i]);
}
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << fp32_output_data[i] << " ,";
}
std::cout << std::endl;

std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin";
lite::CompareOutput(fp32_output_data, output_path);

free(nhwc4_input_data);
free(packed_input);
free(bias_data);
free(packed_weight);
free(tmp_output_block);
delete conv_param;
delete input_data;
delete weight_data;
delete[] fp16_weight_data;
delete[] fp16_input_data;
delete[] fp32_output_data;
delete[] output_data;
MS_LOG(INFO) << "TestConvolutionFp16 passed";
}

TEST_F(TestConvolutionFp16, ConvTest2) {
// prepare stage
auto conv_param = new ConvParameter();
InitConvParamGroup2Fp16(conv_param);

// parameter
int tile_num = 16;
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int kernel_plane = k_h * k_w;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);

// weight
size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
int weight_ele_size = weight_size / sizeof(float);
auto fp16_weight_data = new float16_t[weight_ele_size];
for (int i = 0; i < weight_ele_size; i++) {
fp16_weight_data[i] = static_cast<float16_t>(weight_data[i]);
}
auto packed_weight = reinterpret_cast<float16_t *>(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float16_t)));
PackWeightFp16(fp16_weight_data, conv_param, packed_weight);

// input
size_t input_size;
std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
int input_ele_size = input_size / sizeof(float);
auto fp16_input_data = new float16_t[input_ele_size];
for (int i = 0; i < input_ele_size; i++) {
fp16_input_data[i] = static_cast<float16_t>(input_data[i]);
}
int output_count = conv_param->output_h_ * conv_param->output_w_;
int output_tile_count = UP_DIV(output_count, tile_num);
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_num * unit_size;
auto packed_input = reinterpret_cast<float16_t *>(malloc(in_batch * packed_input_size * sizeof(float16_t)));
memset(packed_input, 0, in_batch * packed_input_size * sizeof(float16_t));

// bias
auto bias_data = reinterpret_cast<float16_t *>(malloc(conv_param->output_channel_ * sizeof(float16_t)));
memset(bias_data, 0, conv_param->output_channel_ * sizeof(float16_t));

// output
auto tmp_output_block = reinterpret_cast<float16_t *>(malloc(tile_num * out_channel * sizeof(float16_t)));
size_t output_data_size =
conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
auto output_data = new float16_t[output_data_size];

// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
// warmup
for (int i = 0; i < 3; i++) {
ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param);
}

int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
ConvFp16(fp16_input_data, packed_input, packed_weight, bias_data, tmp_output_block, output_data, 0, conv_param);
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
time_avg = cost / loop_count;
printf("single thread running time : %f ms\n", time_avg / 1000.0f);

std::cout << "==============fp16 output data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;

auto fp32_output_data = new float[output_data_size];
for (int i = 0; i < output_data_size; i++) {
fp32_output_data[i] = static_cast<float>(output_data[i]);
}
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << fp32_output_data[i] << " ,";
}
std::cout << std::endl;

std::string output_path = "./test_data/conv/convfp32_out_1_128_128_32.bin";
lite::CompareOutput(fp32_output_data, output_path);

free(packed_input);
free(bias_data);
free(packed_weight);
free(tmp_output_block);
delete conv_param;
delete input_data;
delete weight_data;
delete[] fp16_weight_data;
delete[] fp16_input_data;
delete[] fp32_output_data;
delete[] output_data;
MS_LOG(INFO) << "TestConvolutionFp16 passed";
}

TEST_F(TestConvolutionFp16, Conv3x3Test1) {
auto conv_param = new ConvParameter();
InitConvParamGroup1Fp16(conv_param);
int thread_count = 1;
int tile_num = 16;
int output_batch = conv_param->output_batch_;
int output_h = conv_param->output_h_;
int output_w = conv_param->output_w_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);

// tmp buffer
int k_plane = 36;
size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t);
float16_t *tile_buffer = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
memset(tile_buffer, 0, tile_buffer_size);

size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t);
float16_t *block_unit_buffer = reinterpret_cast<float16_t *>(malloc(block_unit_buffer_size));
memset(block_unit_buffer, 0, block_unit_buffer_size);

size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t);
float16_t *tmp_dst_buffer = reinterpret_cast<float16_t *>(malloc(tmp_dst_buffer_size));
memset(tmp_dst_buffer, 0, tmp_dst_buffer_size);

size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t);
float16_t *tmp_out = reinterpret_cast<float16_t *>(malloc(tmp_out_size));
memset(tmp_out, 0, tmp_out_size);

// weight
size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
std::cout << "==============fp32 weight data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << weight_data[i] << ", ";
}
std::cout << std::endl;

std::cout << "weight data size: " << weight_size / sizeof(float) << std::endl;

int weight_ele_size = weight_size / sizeof(float);
auto fp16_weight_data = new float16_t[weight_ele_size];
for (int i = 0; i < weight_ele_size; i++) {
fp16_weight_data[i] = (float16_t)weight_data[i];
}

std::cout << "==============fp16 weight data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << fp16_weight_data[i] << ", ";
}
std::cout << std::endl;

size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36;
auto transformed_weight_data = new float16_t[transformed_size];
memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t));
kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param);

// bias
auto bias_data =
reinterpret_cast<float16_t *>(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)));
memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t));

// input
size_t input_size;
std::string input_path = "./test_data/conv/convfp32_input_1_28_28_3.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
std::cout << "==============fp32 input data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << input_data[i] << ", ";
}
std::cout << std::endl;

int input_ele_size = input_size / sizeof(float);
auto fp16_input_data = new float16_t[input_ele_size];
for (int i = 0; i < input_ele_size; i++) {
fp16_input_data[i] = static_cast<float16_t>(input_data[i]);
}

std::cout << "==============fp16 input data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << fp16_input_data[i] << ", ";
}
std::cout << std::endl;

// output
size_t output_data_size =
conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
auto output_data = new float16_t[output_data_size];

// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
// warmup
for (int i = 0; i < 3; i++) {
Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer,
tmp_dst_buffer, tmp_out, 0, conv_param);
}

int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer,
tmp_dst_buffer, tmp_out, 0, conv_param);
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
time_avg = cost / loop_count;
printf("single thread running time : %f ms\n", time_avg / 1000.0f);

std::cout << "==============fp16 output data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;

auto fp32_output_data = new float[output_data_size];
for (int i = 0; i < output_data_size; i++) {
fp32_output_data[i] = static_cast<float>(output_data[i]);
}
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << fp32_output_data[i] << " ,";
}
std::cout << std::endl;

std::string output_path = "./test_data/conv/convfp32_out_1_28_28_32.bin";
lite::CompareOutput(fp32_output_data, output_path);

free(bias_data);
free(tile_buffer);
free(block_unit_buffer);
free(tmp_dst_buffer);
free(tmp_out);
delete input_data;
delete weight_data;
delete conv_param;
delete[] fp16_weight_data;
delete[] fp16_input_data;
delete[] fp32_output_data;
delete[] output_data;
delete[] transformed_weight_data;
MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed";
}

TEST_F(TestConvolutionFp16, Conv3x3Test2) {
auto conv_param = new ConvParameter();
InitConvParamGroup2Fp16(conv_param);
int thread_count = 1;
int tile_num = 16;
int output_batch = conv_param->output_batch_;
int output_h = conv_param->output_h_;
int output_w = conv_param->output_w_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);

// tmp buffer
int k_plane = 36;
size_t tile_buffer_size = thread_count * tile_num * k_plane * ic4 * C4NUM * sizeof(float16_t);
float16_t *tile_buffer = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
memset(tile_buffer, 0, tile_buffer_size);

size_t block_unit_buffer_size = thread_count * k_plane * C4NUM * sizeof(float16_t);
float16_t *block_unit_buffer = reinterpret_cast<float16_t *>(malloc(block_unit_buffer_size));
memset(block_unit_buffer, 0, block_unit_buffer_size);

size_t tmp_dst_buffer_size = thread_count * tile_num * k_plane * oc8 * C8NUM * sizeof(float16_t);
float16_t *tmp_dst_buffer = reinterpret_cast<float16_t *>(malloc(tmp_dst_buffer_size));
memset(tmp_dst_buffer, 0, tmp_dst_buffer_size);

size_t tmp_out_size = oc8 * C8NUM * output_batch * output_h * output_w * tile_num * sizeof(float16_t);
float16_t *tmp_out = reinterpret_cast<float16_t *>(malloc(tmp_out_size));
memset(tmp_out, 0, tmp_out_size);

// weight
size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_32.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
int weight_ele_size = weight_size / sizeof(float);
auto fp16_weight_data = new float16_t[weight_ele_size];
for (int i = 0; i < weight_ele_size; i++) {
fp16_weight_data[i] = static_cast<float16_t>(weight_data[i]);
}
size_t transformed_size = ic4 * C4NUM * oc8 * C8NUM * 36;
auto transformed_weight_data = new float16_t[transformed_size];
memset(transformed_weight_data, 0, transformed_size * sizeof(float16_t));
kernel::ProcessFilterFp16(fp16_weight_data, transformed_weight_data, conv_param);

// bias
auto bias_data =
reinterpret_cast<float16_t *>(malloc(UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t)));
memset(bias_data, 0, UP_DIV(conv_param->output_channel_, 8) * 8 * sizeof(float16_t));

// input
size_t input_size;
std::string input_path = "./test_data/conv/convfp32_input_1_128_128_32.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
int input_ele_size = input_size / sizeof(float);
auto fp16_input_data = new float16_t[input_ele_size];
for (int i = 0; i < input_ele_size; i++) {
fp16_input_data[i] = static_cast<float16_t>(input_data[i]);
}

// output
size_t output_data_size =
conv_param->output_batch_ * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
auto output_data = new float16_t[output_data_size];

// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
// warmup
for (int i = 0; i < 3; i++) {
Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer,
tmp_dst_buffer, tmp_out, 0, conv_param);
}

int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
Conv3x3Fp16(fp16_input_data, transformed_weight_data, bias_data, output_data, tile_buffer, block_unit_buffer,
tmp_dst_buffer, tmp_out, 0, conv_param);
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
time_avg = cost / loop_count;
printf("single thread running time : %f ms\n", time_avg / 1000.0f);

std::cout << "==============fp16 output data===========" << std::endl;
for (int i = 0; i < 20; i++) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;

auto fp32_output_data = new float[output_data_size];
for (int i = 0; i < output_data_size; i++) {
fp32_output_data[i] = static_cast<float>(output_data[i]);
}
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << fp32_output_data[i] << " ,";
}
std::cout << std::endl;

std::string output_path = "./test_data/conv/convfp32_out_1_128_128_32.bin";
lite::CompareOutput(fp32_output_data, output_path);

free(bias_data);
free(tile_buffer);
free(block_unit_buffer);
free(tmp_dst_buffer);
free(tmp_out);
delete input_data;
delete weight_data;
delete conv_param;
delete[] fp16_weight_data;
delete[] fp16_input_data;
delete[] fp32_output_data;
delete[] output_data;
delete[] transformed_weight_data;
MS_LOG(INFO) << "TestConvolutionFp16 Conv3x3 passed";
}

} // namespace mindspore
