|
|
|
@@ -18,6 +18,7 @@ |
|
|
|
#include "nnacl/common_func.h" |
|
|
|
#include "nnacl/fp32/common_func_fp32.h" |
|
|
|
#include "nnacl/fp32/winograd_transform.h" |
|
|
|
#include "nnacl/intrinsics/ms_simd_instructions.h" |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
#include <arm_neon.h> |
|
|
|
#endif |
|
|
|
@@ -337,260 +338,373 @@ bool CheckConvDwUse3X3(const ConvParameter *conv_param) { |
|
|
|
in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3BorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, |
|
|
|
int in_kh_step, int in_kw_step, int channel, bool relu, bool relu6) { |
|
|
|
for (int c = 0; c < channel; c += C4NUM) { |
|
|
|
for (int i = 0; i < C4NUM; i++) { |
|
|
|
dst[i] = 0; |
|
|
|
} |
|
|
|
const float *src_kh = src; |
|
|
|
const float *weight_kh = weight; |
|
|
|
for (int kh = 0; kh < height; kh++) { |
|
|
|
const float *src_kw = src_kh; |
|
|
|
const float *weight_kw = weight_kh; |
|
|
|
for (int kw = 0; kw < width; kw++) { |
|
|
|
for (int i = 0; i < C4NUM; i++) { |
|
|
|
dst[i] += src_kw[c + i] * weight_kw[c + i]; |
|
|
|
} |
|
|
|
src_kw += in_kw_step; |
|
|
|
weight_kw += channel; |
|
|
|
} // kernel_w loop |
|
|
|
src_kh += in_kh_step; |
|
|
|
weight_kh += 3 * channel; |
|
|
|
} // kernel_h loop |
|
|
|
for (int i = 0; i < C4NUM; i++) { |
|
|
|
dst[i] += bias[c + i]; |
|
|
|
dst[i] = (relu) ? (MSMAX(0, dst[i])) : (dst[i]); |
|
|
|
dst[i] = (relu6) ? (MSMIN(6, MSMAX(0, dst[i]))) : (dst[i]); |
|
|
|
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) |
|
|
|
bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) { |
|
|
|
return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 && |
|
|
|
conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && |
|
|
|
conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 && |
|
|
|
conv_param->input_channel_ == conv_param->output_channel_ && |
|
|
|
conv_param->output_h_ / thread_num >= 4; // better had more than 4 rows for each thread |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) { |
|
|
|
MS_FLOAT32X4 v0, v1, v2, v3; |
|
|
|
v0 = MS_MOVQ_F32(0.0f); |
|
|
|
int ic = 0; |
|
|
|
for (; ic < channel - 3; ic += 4) { |
|
|
|
v1 = MS_LDQ_F32(src + ic); |
|
|
|
v2 = MS_LDQ_F32(src + channel + ic); |
|
|
|
v3 = MS_LDQ_F32(src + 2 * channel + ic); |
|
|
|
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); |
|
|
|
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); |
|
|
|
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); |
|
|
|
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); |
|
|
|
MS_STQ_F32(line + lw * ic, b0); |
|
|
|
MS_STQ_F32(line + lw * ic + 4, b1); |
|
|
|
MS_STQ_F32(line + lw * ic + 8, b2); |
|
|
|
MS_STQ_F32(line + lw * ic + 12, b3); |
|
|
|
} |
|
|
|
if (ic < channel) { |
|
|
|
float *remain_line = line + ic * lw; |
|
|
|
memset(remain_line, 0, 16); |
|
|
|
memset(remain_line + 4, 0, 16); |
|
|
|
memset(remain_line + 8, 0, 16); |
|
|
|
memset(remain_line + 12, 0, 16); |
|
|
|
for (int i = 0; i < channel - ic; i++) { |
|
|
|
float d1 = src[i + ic]; |
|
|
|
float d2 = src[i + ic + channel]; |
|
|
|
float d3 = src[i + ic + 2 * channel]; |
|
|
|
remain_line[i] = 0.0f - d2; |
|
|
|
remain_line[i + 4] = d1 + d2; |
|
|
|
remain_line[i + 8] = d2 - d1; |
|
|
|
remain_line[i + 12] = d3 - d1; |
|
|
|
} |
|
|
|
dst += C4NUM; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, |
|
|
|
int in_kw_step, int channel, bool relu, bool relu6) { |
|
|
|
ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, relu, relu6); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, |
|
|
|
int in_kw_step, int channel, bool relu, bool relu6) { |
|
|
|
ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, relu, relu6); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, |
|
|
|
int in_kw_step, int channel, bool relu, bool relu6) { |
|
|
|
ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, relu, relu6); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, |
|
|
|
const ConvParameter *conv_param, const SlidingWindowParam *sliding) { |
|
|
|
int input_row_size = conv_param->input_w_ * conv_param->input_channel_; |
|
|
|
int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; |
|
|
|
int output_row_size = conv_param->output_w_ * conv_param->output_channel_; |
|
|
|
int in_kh_step = sliding->in_kh_step_; |
|
|
|
int in_kw_step = sliding->in_kw_step_; |
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu; |
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6; |
|
|
|
|
|
|
|
for (int b = 0; b < conv_param->output_batch_; b++) { |
|
|
|
const float *input_batch = |
|
|
|
input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; |
|
|
|
float *output_batch = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; |
|
|
|
// top |
|
|
|
const float *input = input_batch; |
|
|
|
const float *weight = weight_data + weight_row_size + conv_param->input_channel_; |
|
|
|
float *output = output_batch; |
|
|
|
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); |
|
|
|
input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; |
|
|
|
weight = weight_data + weight_row_size; |
|
|
|
output += conv_param->output_channel_; |
|
|
|
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { |
|
|
|
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, |
|
|
|
relu6); |
|
|
|
input += conv_param->stride_w_ * conv_param->input_channel_; |
|
|
|
output += conv_param->output_channel_; |
|
|
|
} |
|
|
|
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); |
|
|
|
|
|
|
|
// left |
|
|
|
input = input_batch + (conv_param->stride_h_ - 1) * input_row_size; |
|
|
|
weight = weight_data + conv_param->input_channel_; |
|
|
|
output = output_batch + output_row_size; |
|
|
|
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { |
|
|
|
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, |
|
|
|
relu6); |
|
|
|
input += conv_param->stride_h_ * input_row_size; |
|
|
|
output += output_row_size; |
|
|
|
void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) { |
|
|
|
MS_FLOAT32X4 v0, v1, v2, v3; |
|
|
|
int ic = 0; |
|
|
|
for (; ic < channel - 3; ic += 4) { |
|
|
|
v0 = MS_LDQ_F32(src + ic); |
|
|
|
v1 = MS_LDQ_F32(src + channel + ic); |
|
|
|
v2 = MS_LDQ_F32(src + 2 * channel + ic); |
|
|
|
v3 = MS_LDQ_F32(src + 3 * channel + ic); |
|
|
|
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); |
|
|
|
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); |
|
|
|
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); |
|
|
|
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); |
|
|
|
MS_STQ_F32(line + lw * ic, b0); |
|
|
|
MS_STQ_F32(line + lw * ic + 4, b1); |
|
|
|
MS_STQ_F32(line + lw * ic + 8, b2); |
|
|
|
MS_STQ_F32(line + lw * ic + 12, b3); |
|
|
|
} |
|
|
|
if (ic < channel) { |
|
|
|
float *remain_line = line + ic * lw; |
|
|
|
memset(remain_line, 0, 16); |
|
|
|
memset(remain_line + 4, 0, 16); |
|
|
|
memset(remain_line + 8, 0, 16); |
|
|
|
memset(remain_line + 12, 0, 16); |
|
|
|
for (int i = 0; i < channel - ic; i++) { |
|
|
|
float d0 = src[i + ic]; |
|
|
|
float d1 = src[i + ic + channel]; |
|
|
|
float d2 = src[i + ic + 2 * channel]; |
|
|
|
float d3 = src[i + ic + 3 * channel]; |
|
|
|
remain_line[i] = d0 - d2; |
|
|
|
remain_line[i + 4] = d1 + d2; |
|
|
|
remain_line[i + 8] = d2 - d1; |
|
|
|
remain_line[i + 12] = d3 - d1; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// right |
|
|
|
input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + |
|
|
|
(conv_param->stride_h_ - 1) * input_row_size; |
|
|
|
weight = weight_data; |
|
|
|
output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; |
|
|
|
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { |
|
|
|
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, |
|
|
|
relu6); |
|
|
|
input += conv_param->stride_h_ * input_row_size; |
|
|
|
output += output_row_size; |
|
|
|
void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) { |
|
|
|
MS_FLOAT32X4 v0, v1, v2, v3; |
|
|
|
int ic = 0; |
|
|
|
v3 = MS_MOVQ_F32(0.0f); |
|
|
|
for (; ic < channel - 3; ic += 4) { |
|
|
|
v0 = MS_LDQ_F32(src + ic); |
|
|
|
v1 = MS_LDQ_F32(src + channel + ic); |
|
|
|
v2 = MS_LDQ_F32(src + 2 * channel + ic); |
|
|
|
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); |
|
|
|
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); |
|
|
|
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); |
|
|
|
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); |
|
|
|
MS_STQ_F32(line + lw * ic, b0); |
|
|
|
MS_STQ_F32(line + lw * ic + 4, b1); |
|
|
|
MS_STQ_F32(line + lw * ic + 8, b2); |
|
|
|
MS_STQ_F32(line + lw * ic + 12, b3); |
|
|
|
} |
|
|
|
if (ic < channel) { |
|
|
|
float *remain_line = line + ic * lw; |
|
|
|
memset(remain_line, 0, 16); |
|
|
|
memset(remain_line + 4, 0, 16); |
|
|
|
memset(remain_line + 8, 0, 16); |
|
|
|
memset(remain_line + 12, 0, 16); |
|
|
|
for (int i = 0; i < channel - ic; i++) { |
|
|
|
float d0 = src[i + ic]; |
|
|
|
float d1 = src[i + ic + channel]; |
|
|
|
float d2 = src[i + ic + 2 * channel]; |
|
|
|
remain_line[i] = d0 - d2; |
|
|
|
remain_line[i + 4] = d1 + d2; |
|
|
|
remain_line[i + 8] = d2 - d1; |
|
|
|
remain_line[i + 12] = 0.0f - d1; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// bottom |
|
|
|
input = input_batch + (conv_param->input_h_ - 2) * input_row_size; |
|
|
|
weight = weight_data + conv_param->input_channel_; |
|
|
|
output = output_batch + (conv_param->output_h_ - 1) * output_row_size; |
|
|
|
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); |
|
|
|
input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; |
|
|
|
weight = weight_data; |
|
|
|
output += conv_param->output_channel_; |
|
|
|
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { |
|
|
|
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, |
|
|
|
relu6); |
|
|
|
input += conv_param->stride_w_ * conv_param->input_channel_; |
|
|
|
output += conv_param->output_channel_; |
|
|
|
void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) { |
|
|
|
MS_FLOAT32X4 v0, v1, v2; |
|
|
|
int ic = 0; |
|
|
|
v2 = MS_MOVQ_F32(0.0f); |
|
|
|
for (; ic < channel - 3; ic += 4) { |
|
|
|
v0 = MS_LDQ_F32(src + ic); |
|
|
|
v1 = MS_LDQ_F32(src + channel + ic); |
|
|
|
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); |
|
|
|
MS_STQ_F32(line + lw * ic, v0); |
|
|
|
MS_STQ_F32(line + lw * ic + 4, v1); |
|
|
|
MS_STQ_F32(line + lw * ic + 8, b2); |
|
|
|
memset(line + lw * ic + 12, 0, 16); |
|
|
|
} |
|
|
|
if (ic < channel) { |
|
|
|
float *remain_line = line + ic * lw; |
|
|
|
memset(remain_line, 0, 16); |
|
|
|
memset(remain_line + 4, 0, 16); |
|
|
|
memset(remain_line + 8, 0, 16); |
|
|
|
memset(remain_line + 12, 0, 16); |
|
|
|
for (int i = 0; i < channel - ic; i++) { |
|
|
|
float d0 = src[i + ic]; |
|
|
|
float d1 = src[i + ic + channel]; |
|
|
|
remain_line[i] = d0; |
|
|
|
remain_line[i + 4] = d1; |
|
|
|
remain_line[i + 8] = 0.0f - d1; |
|
|
|
} |
|
|
|
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3InitBuffer(float *buffer, const float *input, const ConvParameter *conv_param, int block_input_h, |
|
|
|
int block_input_w) { |
|
|
|
for (int h = 0; h < block_input_h; h++) { |
|
|
|
const float *src = input; |
|
|
|
for (int w = 0; w < block_input_w; w++) { |
|
|
|
memcpy(buffer, src, 64 * sizeof(float)); |
|
|
|
src += conv_param->input_channel_; |
|
|
|
buffer += 64; |
|
|
|
} |
|
|
|
input += conv_param->input_w_ * conv_param->input_channel_; |
|
|
|
void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) { |
|
|
|
float *line0 = lines[0]; |
|
|
|
float *line1 = lines[1]; |
|
|
|
float *line2 = lines[2]; |
|
|
|
int c4 = UP_ROUND(channel, C4NUM); |
|
|
|
int lw = UP_DIV(width, C2NUM) * C4NUM; |
|
|
|
memset(line0, 0, c4 * lw * sizeof(float)); |
|
|
|
ConvDw3x3RowLeft(src, line1, lw, channel); |
|
|
|
ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); |
|
|
|
int ow = 2; |
|
|
|
for (; ow < width - 2; ow += 2) { |
|
|
|
ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
int remain = width - ow; |
|
|
|
if (remain == 2) { |
|
|
|
ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} else if (remain == 1) { |
|
|
|
ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Window(float *output, const float *buffer, const float *weight, const float *bias, int col_size, |
|
|
|
int row_size, int channel, int output_h, int output_w, int stride, bool relu, bool relu6) { |
|
|
|
for (int w = 0; w < output_w; w++) { |
|
|
|
for (int i = 0; i < C4NUM; i++) { |
|
|
|
output[i] = bias[i]; |
|
|
|
} |
|
|
|
const float *src_kh = buffer; |
|
|
|
const float *weight_kh = weight; |
|
|
|
for (int kh = 0; kh < 3; kh++) { |
|
|
|
const float *src_kw = src_kh; |
|
|
|
const float *weight_kw = weight_kh; |
|
|
|
for (int kw = 0; kw < 3; kw++) { |
|
|
|
for (int c = 0; c < C4NUM; c++) { |
|
|
|
output[c] += src_kw[c] * weight_kw[c]; |
|
|
|
} |
|
|
|
src_kw += col_size; |
|
|
|
weight_kw += channel; |
|
|
|
} |
|
|
|
src_kh += row_size; |
|
|
|
weight_kh += 3 * channel; |
|
|
|
} |
|
|
|
for (int i = 0; i < C4NUM; i++) { |
|
|
|
output[i] = (relu) ? (MSMAX(0, output[i])) : (output[i]); |
|
|
|
output[i] = (relu6) ? (MSMIN(6, MSMAX(0, output[i]))) : (output[i]); |
|
|
|
} |
|
|
|
output += channel; |
|
|
|
buffer += col_size * stride; |
|
|
|
void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) { |
|
|
|
float *line0 = lines[0]; |
|
|
|
float *line1 = lines[1]; |
|
|
|
float *line2 = lines[2]; |
|
|
|
int lw = UP_DIV(width, C2NUM) * C4NUM; |
|
|
|
ConvDw3x3RowLeft(src - width * channel, line0, lw, channel); |
|
|
|
ConvDw3x3RowLeft(src, line1, lw, channel); |
|
|
|
ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); |
|
|
|
int ow = 2; |
|
|
|
for (; ow < width - 2; ow += 2) { |
|
|
|
ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
int remain = width - ow; |
|
|
|
if (remain == 2) { |
|
|
|
ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} else if (remain == 1) { |
|
|
|
ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); |
|
|
|
ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Block(float *output, const float *buffer, const float *weight, const float *bias, int start_c, int end_c, |
|
|
|
int col_size, int row_size, int channel, int output_h, int output_w, int stride, bool relu, |
|
|
|
bool relu6) { |
|
|
|
for (; start_c <= end_c - C4NUM; start_c += C4NUM) { |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
if (stride == 1) { |
|
|
|
ConvDw3x3Stride1(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6); |
|
|
|
} else { |
|
|
|
ConvDw3x3Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6); |
|
|
|
} |
|
|
|
#else |
|
|
|
ConvDw3x3Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, stride, relu, relu6); |
|
|
|
#endif |
|
|
|
output += C4NUM; |
|
|
|
buffer += C4NUM; |
|
|
|
weight += C4NUM; |
|
|
|
bias += C4NUM; |
|
|
|
void ConvDw3x3Row(const float *src, float **lines, int width, int channel) { |
|
|
|
float *tmp = lines[0]; |
|
|
|
lines[0] = lines[1]; |
|
|
|
lines[1] = lines[2]; |
|
|
|
lines[2] = tmp; |
|
|
|
int c4 = UP_ROUND(channel, C4NUM); |
|
|
|
int lw = UP_DIV(width, C2NUM) * C4NUM; |
|
|
|
memset(tmp, 0, c4 * lw * sizeof(float)); |
|
|
|
ConvDw3x3RowLeft(src, tmp, lw, channel); |
|
|
|
int ow = 2; |
|
|
|
for (; ow < width - 2; ow += 2) { |
|
|
|
ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
int remain = width - ow; |
|
|
|
if (remain == 2) { |
|
|
|
ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); |
|
|
|
} else if (remain == 1) { |
|
|
|
ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Row(float *output, float *buffer, const float *input, const float *weight, const float *bias, |
|
|
|
const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w, |
|
|
|
int block_input_h, int block_input_w) { |
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu; |
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6; |
|
|
|
const int ih_offset = 64 * block_input_w; |
|
|
|
int w = start_w; |
|
|
|
if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { |
|
|
|
for (; w <= end_w - block_output_w; w += block_output_w) { |
|
|
|
float *output_ptr = output; |
|
|
|
const float *input_ptr = input; |
|
|
|
const float *weight_ptr = weight; |
|
|
|
const float *bias_ptr = bias; |
|
|
|
int c = 0; |
|
|
|
for (; c <= conv_param->output_channel_ - 64; c += 64) { |
|
|
|
ConvDw3x3InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); |
|
|
|
ConvDw3x3Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, |
|
|
|
block_output_h, block_output_w, conv_param->stride_h_, relu, relu6); |
|
|
|
output_ptr += 64; |
|
|
|
input_ptr += 64; |
|
|
|
weight_ptr += 64; |
|
|
|
bias_ptr += 64; |
|
|
|
void ConvDw3x3Bottom(float **lines, int width, int channel) { |
|
|
|
float *tmp = lines[0]; |
|
|
|
lines[0] = lines[1]; |
|
|
|
lines[1] = lines[2]; |
|
|
|
lines[2] = tmp; |
|
|
|
int c4 = UP_ROUND(channel, C4NUM); |
|
|
|
memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float)); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, |
|
|
|
bool relu, bool relu6) { |
|
|
|
int channel = ori_channel; |
|
|
|
float *line0 = lines[0]; |
|
|
|
float *line1 = lines[1]; |
|
|
|
float *line2 = lines[2]; |
|
|
|
for (; channel > 0; channel -= 4) { |
|
|
|
MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data); |
|
|
|
bias_data += 4; |
|
|
|
MS_FLOAT32X4 g00 = MS_LDQ_F32(weight); |
|
|
|
MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4); |
|
|
|
MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8); |
|
|
|
MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12); |
|
|
|
MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16); |
|
|
|
MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20); |
|
|
|
MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24); |
|
|
|
MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28); |
|
|
|
MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32); |
|
|
|
MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36); |
|
|
|
MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40); |
|
|
|
MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44); |
|
|
|
weight += 48; |
|
|
|
float *cur_dst = dst; |
|
|
|
int ow = 0; |
|
|
|
for (; ow < width - 1; ow += 2) { |
|
|
|
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); |
|
|
|
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); |
|
|
|
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); |
|
|
|
MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03); |
|
|
|
line0 += 16; |
|
|
|
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); |
|
|
|
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); |
|
|
|
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); |
|
|
|
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13); |
|
|
|
line1 += 16; |
|
|
|
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); |
|
|
|
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); |
|
|
|
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); |
|
|
|
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23); |
|
|
|
line2 += 16; |
|
|
|
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); |
|
|
|
MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2)); |
|
|
|
res0 = MS_ADDQ_F32(res0, bias); |
|
|
|
res1 = MS_ADDQ_F32(res1, bias); |
|
|
|
if (relu || relu6) { |
|
|
|
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); |
|
|
|
res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f)); |
|
|
|
} |
|
|
|
if (relu6) { |
|
|
|
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); |
|
|
|
res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f)); |
|
|
|
} |
|
|
|
// left channel |
|
|
|
ConvDw3x3Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, |
|
|
|
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, |
|
|
|
conv_param->input_channel_, block_output_h, block_output_w, conv_param->stride_h_, relu, relu6); |
|
|
|
output += block_output_w * conv_param->input_channel_; |
|
|
|
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; |
|
|
|
if (channel >= 4) { |
|
|
|
MS_STQ_F32(cur_dst, res0); |
|
|
|
MS_STQ_F32(cur_dst + ori_channel, res1); |
|
|
|
} else { |
|
|
|
for (int i = 0; i < channel; i++) { |
|
|
|
cur_dst[i] = res0[i]; |
|
|
|
cur_dst[ori_channel + i] = res1[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
cur_dst += 2 * ori_channel; |
|
|
|
} |
|
|
|
} |
|
|
|
// left width |
|
|
|
int left_width = end_w - w; |
|
|
|
if (left_width > 0) { |
|
|
|
ConvDw3x3Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_, |
|
|
|
conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h, |
|
|
|
left_width, conv_param->stride_h_, relu, relu6); |
|
|
|
if (ow < width) { |
|
|
|
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); |
|
|
|
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); |
|
|
|
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); |
|
|
|
line0 += 16; |
|
|
|
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); |
|
|
|
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); |
|
|
|
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); |
|
|
|
line1 += 16; |
|
|
|
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); |
|
|
|
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); |
|
|
|
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); |
|
|
|
line2 += 16; |
|
|
|
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); |
|
|
|
res0 = MS_ADDQ_F32(res0, bias); |
|
|
|
if (relu || relu6) { |
|
|
|
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); |
|
|
|
} |
|
|
|
if (relu6) { |
|
|
|
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); |
|
|
|
} |
|
|
|
if (channel >= 4) { |
|
|
|
MS_STQ_F32(cur_dst, res0); |
|
|
|
} else { |
|
|
|
for (int i = 0; i < channel; i++) { |
|
|
|
cur_dst[i] = res0[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
dst += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, |
|
|
|
const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, |
|
|
|
int task_id) { |
|
|
|
int output_h = sliding->bottom_ - sliding->top_; |
|
|
|
int step_oh = UP_DIV(output_h, conv_param->thread_num_); |
|
|
|
int start_oh = step_oh * task_id + sliding->top_; |
|
|
|
int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_); |
|
|
|
int start_ow = sliding->left_; |
|
|
|
int end_ow = sliding->right_; |
|
|
|
|
|
|
|
const int block_output_h = 1; |
|
|
|
int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14; |
|
|
|
const int block_input_h = 3; |
|
|
|
int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3; |
|
|
|
const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { |
|
|
|
int units = UP_DIV(conv_param->output_w_, C2NUM); |
|
|
|
int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); |
|
|
|
int line = conv_param->input_channel_ * conv_param->input_w_; |
|
|
|
|
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu; |
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6; |
|
|
|
|
|
|
|
for (int b = 0; b < conv_param->output_batch_; b++) { |
|
|
|
int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_; |
|
|
|
int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_; |
|
|
|
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ + |
|
|
|
start_ih * conv_param->input_w_ * conv_param->input_channel_ + |
|
|
|
start_iw * conv_param->input_channel_; |
|
|
|
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ + |
|
|
|
start_oh * conv_param->output_w_ * conv_param->output_channel_ + |
|
|
|
start_ow * conv_param->output_channel_; |
|
|
|
|
|
|
|
for (int oh = start_oh; oh < end_oh; oh++) { |
|
|
|
ConvDw3x3Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h, |
|
|
|
block_output_w, block_input_h, block_input_w); |
|
|
|
src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_; |
|
|
|
dst += conv_param->output_w_ * conv_param->output_channel_; |
|
|
|
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; |
|
|
|
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; |
|
|
|
float *line0 = buffer; |
|
|
|
float *line1 = buffer + units * c4 * C4NUM; |
|
|
|
float *line2 = buffer + units * c4 * C8NUM; |
|
|
|
float *lines[3] = {line0, line1, line2}; |
|
|
|
int oh = start_oh; |
|
|
|
if (oh == 0) { |
|
|
|
// input trans |
|
|
|
ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_); |
|
|
|
} else { |
|
|
|
// input trans |
|
|
|
ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); |
|
|
|
} |
|
|
|
// dst calc and trans |
|
|
|
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, |
|
|
|
relu, relu6); |
|
|
|
for (oh = start_oh + 1; oh < end_oh - 1; oh++) { |
|
|
|
// input trans |
|
|
|
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); |
|
|
|
// dst calc and trans |
|
|
|
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, |
|
|
|
relu, relu6); |
|
|
|
} |
|
|
|
if (oh == conv_param->output_h_ - 1) { |
|
|
|
// input trans |
|
|
|
ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_); |
|
|
|
} else { |
|
|
|
// input trans |
|
|
|
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); |
|
|
|
} |
|
|
|
// dst calc and trans |
|
|
|
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, |
|
|
|
relu, relu6); |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
/*conv depthwise indirect buffer fp32 begin*/ |
|
|
|
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) { |
|
|
|
|