diff --git a/mindspore/lite/nnacl/fp32/conv.c b/mindspore/lite/nnacl/fp32/conv.c
index ca6c547df8..04e3e9115f 100644
--- a/mindspore/lite/nnacl/fp32/conv.c
+++ b/mindspore/lite/nnacl/fp32/conv.c
@@ -258,9 +258,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
 }
 
 // fp32 conv winograd
-void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
-                      int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
-                      GEMM_FUNC_FP32 gemm_func) {
+void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
+                      TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
+                      OutputTransFunc out_func) {
   int thread_num = conv_param->thread_num_;
   int input_unit = conv_param->input_unit_;
   int in_batch = conv_param->input_batch_;
@@ -277,13 +277,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
 #endif
   int output_tile_count = UP_DIV(output_count, tile_num);
   int out_channel = conv_param->output_channel_;
-  int oc4 = UP_DIV(out_channel, C4NUM);
   int oc8 = UP_DIV(out_channel, C8NUM);
   int input_unit_square = input_unit * input_unit;
 
   float *trans_input = buffer_list[0];
   float *gemm_out = buffer_list[1];
-  float *tmp_out_data = buffer_list[2];
   float *tmp_data = buffer_list[3];
   float *col_buffer = buffer_list[4];
   int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
@@ -294,7 +292,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
   // step 2 : input transform (online)
   for (int b = 0; b < in_batch; b++) {
     int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
-    int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
+    int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
       int out_tile_index = thread_id * tile_num;
       int cal_num = output_count - thread_id * tile_num;
@@ -317,8 +315,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
       }
 
       // step 4 : output transform
-      WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
-                              out_w_block, conv_param, out_func);
+      float *output_ptr = output_data + out_batch_offset;
+      WinogradOutputTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
+                              out_func);
     }
   }
 }
diff --git a/mindspore/lite/nnacl/fp32/conv.h b/mindspore/lite/nnacl/fp32/conv.h
index 7baa37a26a..9b280ed820 100644
--- a/mindspore/lite/nnacl/fp32/conv.h
+++ b/mindspore/lite/nnacl/fp32/conv.h
@@ -53,9 +53,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
               GEMM_FUNC_FP32 gemm_func);
 
 // fp32 convolution winograd
-void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
-                      int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
-                      GEMM_FUNC_FP32 gemm_func);
+void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
+                      TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
+                      OutputTransFunc out_func);
 
 void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
                           int output_unit);
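The change in this file pair routes the Winograd result straight into the NHWC output tensor: ConvWinogardFp32 gains an output_data argument, the packed tmp_out buffer (buffer_list[2]) goes away, and each batch is addressed with a plain NHWC offset. A minimal standalone sketch of that addressing, with made-up shape values (nothing below is from the patch itself):

    #include <stdio.h>

    int main(void) {
      /* Illustrative NHWC output shape: batch 2, 7x7 spatial, 5 channels. */
      int batch = 2, output_h = 7, output_w = 7, out_channel = 5;
      for (int b = 0; b < batch; b++) {
        /* Mirrors the new per-batch offset in ConvWinogardFp32. */
        int out_batch_offset = b * out_channel * output_w * output_h;
        /* Pixel (y, x), channel c lands at the same flat offset the output
         * transform uses below: out_channel * (x + y * output_w) + c. */
        int y = 6, x = 3, c = 4;
        printf("b=%d -> flat index %d\n", b, out_batch_offset + out_channel * (x + y * output_w) + c);
      }
      return 0;
    }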
diff --git a/mindspore/lite/nnacl/winograd_transform.c b/mindspore/lite/nnacl/winograd_transform.c
index 933e223f64..a298b8d657 100644
--- a/mindspore/lite/nnacl/winograd_transform.c
+++ b/mindspore/lite/nnacl/winograd_transform.c
@@ -82,13 +82,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
   }  // cal_tile_num loop
 }
 
-void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
+void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
                              int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func) {
   int output_unit = conv_param->output_unit_;
   int output_w = conv_param->output_w_;
   int output_h = conv_param->output_h_;
-  int output_w_unit_block = UP_DIV(output_w, output_unit);
-  int output_h_unit_block = UP_DIV(output_h, output_unit);
   int output_channel = conv_param->output_channel_;
   int oc4 = UP_DIV(output_channel, C4NUM);
   int oc8 = UP_DIV(output_channel, C8NUM);
@@ -99,19 +97,29 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
   for (int i = 0; i < cal_num; i++) {
     int dst_x_s = out_tile_index % output_unit_num;
     int dst_y_s = out_tile_index / output_unit_num;
+    int r_w = output_w - dst_x_s * output_unit;
+    r_w = r_w > output_unit ? output_unit : r_w;
+    int r_h = output_h - dst_y_s * output_unit;
+    r_h = r_h > output_unit ? output_unit : r_h;
+    int tmp_ix = dst_x_s * output_unit;
+    dst_x_s = tmp_ix > output_w ? output_w : tmp_ix;
+    int tmp_iy = dst_y_s * output_unit;
+    dst_y_s = tmp_iy > output_h ? output_h : tmp_iy;
+
     int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
-    int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
+    int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w);
 
     for (int j = 0; j < oc4; j++) {
      int c8_block = j / 2;
      int c8_res = j % 2;
+      int r_c = output_channel - j * C4NUM;
+      r_c = r_c > C4NUM ? C4NUM : r_c;
      int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
-      int dst_oc4_offset =
-        dst_tile_offset + j * C4NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
+      int dst_oc4_offset = dst_tile_offset + j * C4NUM;
      const float *src_ptr = gemm_out + src_oc4_offset;
      const float *bias_ptr = bias_data + j * C4NUM;
-      float *dst_ptr = tmp_out_data + dst_oc4_offset;
-      func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
+      float *dst_ptr = out_data + dst_oc4_offset;
+      func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c);
      // GeneralOutputTransformUnit(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM,
      //                            output_w_unit_block * output_unit, input_unit, output_unit);
     }
diff --git a/mindspore/lite/nnacl/winograd_transform.h b/mindspore/lite/nnacl/winograd_transform.h
index d8b8914b00..21cfb38ed0 100644
--- a/mindspore/lite/nnacl/winograd_transform.h
+++ b/mindspore/lite/nnacl/winograd_transform.h
@@ -35,7 +35,7 @@ extern "C" {
 void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
                             int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFunc func);
 
-void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
+void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
                              int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);
 
 // for fp32 convolution 3x3 filter/input/output transform
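WinogradOutputTransform now clips each tile against the real output bounds (r_w, r_h) and the real channel count (r_c), so border tiles can be written straight into NHWC memory without padding. A standalone sketch of that clipping arithmetic, assuming an example 7x7, 5-channel output with a 4-unit tile (values are illustrative, not from the patch):

    #include <stdio.h>

    #define C4NUM 4

    int main(void) {
      int output_w = 7, output_h = 7, output_channel = 5, output_unit = 4;
      for (int tile_y = 0; tile_y < 2; tile_y++) {
        for (int tile_x = 0; tile_x < 2; tile_x++) {
          /* Same clamping as the patch: interior tiles keep the full unit,
           * border tiles shrink to what actually fits. */
          int r_w = output_w - tile_x * output_unit;
          r_w = r_w > output_unit ? output_unit : r_w;
          int r_h = output_h - tile_y * output_unit;
          r_h = r_h > output_unit ? output_unit : r_h;
          /* Second c4 block (j == 1) of 5 channels keeps only 1 channel. */
          int r_c = output_channel - 1 * C4NUM;
          r_c = r_c > C4NUM ? C4NUM : r_c;
          printf("tile (%d,%d): store %dx%d pixels, %d channel(s)\n", tile_y, tile_x, r_h, r_w, r_c);
        }
      }
      return 0;
    }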
diff --git a/mindspore/lite/nnacl/winograd_utils.c b/mindspore/lite/nnacl/winograd_utils.c
index 58be4adb92..564d3e4aae 100644
--- a/mindspore/lite/nnacl/winograd_utils.c
+++ b/mindspore/lite/nnacl/winograd_utils.c
@@ -25,9 +25,25 @@ static InputTransFunc InputTransFuncList[] = {
   NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};
 
 static OutputTransFunc OutputTransFuncList4[] = {NULL, NULL, OutputTransform4x2Unit, OutputTransform4x3Unit};
+static OutputTransFunc OutputTransFuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnit,
+                                                     OutputTransform4x3ReluUnit};
+static OutputTransFunc OutputTransFuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6Unit,
+                                                      OutputTransform4x3Relu6Unit};
 
 static OutputTransFunc OutputTransFuncList6[] = {
   NULL, NULL, OutputTransform6x2Unit, OutputTransform6x3Unit, OutputTransform6x4Unit, OutputTransform6x5Unit};
+static OutputTransFunc OutputTransFuncReluList6[] = {NULL,
+                                                     NULL,
+                                                     OutputTransform6x2ReluUnit,
+                                                     OutputTransform6x3ReluUnit,
+                                                     OutputTransform6x4ReluUnit,
+                                                     OutputTransform6x5ReluUnit};
+static OutputTransFunc OutputTransFuncRelu6List6[] = {NULL,
+                                                      NULL,
+                                                      OutputTransform6x2Relu6Unit,
+                                                      OutputTransform6x3Relu6Unit,
+                                                      OutputTransform6x4Relu6Unit,
+                                                      OutputTransform6x5Relu6Unit};
 
 static OutputTransFunc OutputTransFuncList8[] = {NULL,
                                                  NULL,
@@ -37,8 +53,22 @@ static OutputTransFunc OutputTransFuncList8[] = {NULL,
                                                  OutputTransform8x5Unit,
                                                  OutputTransform8x6Unit,
                                                  OutputTransform8x7Unit};
-//
-// static bool InputUnitList[] = {false, false, false, false, true, false, true, false, true};
+static OutputTransFunc OutputTransFuncReluList8[] = {NULL,
+                                                     NULL,
+                                                     OutputTransform8x2ReluUnit,
+                                                     OutputTransform8x3ReluUnit,
+                                                     OutputTransform8x4ReluUnit,
+                                                     OutputTransform8x5ReluUnit,
+                                                     OutputTransform8x6ReluUnit,
+                                                     OutputTransform8x7ReluUnit};
+static OutputTransFunc OutputTransFuncRelu6List8[] = {NULL,
+                                                      NULL,
+                                                      OutputTransform8x2Relu6Unit,
+                                                      OutputTransform8x3Relu6Unit,
+                                                      OutputTransform8x4Relu6Unit,
+                                                      OutputTransform8x5Relu6Unit,
+                                                      OutputTransform8x6Relu6Unit,
+                                                      OutputTransform8x7Relu6Unit};
 
 void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt,
                                int src_step, int dst_step, int in_unit) {
@@ -268,44 +298,45 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
   Load64Data;
   for (int l = 0; l < 8; ++l) {
     int offset = l * 8;
-    t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 36), vmulq_n_f32(src[2 + offset], 49)),
-                               vmulq_n_f32(src[4 + offset], 14)),
+    t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)),
+                               vmulq_n_f32(src[4 + offset], 3.5)),
                      src[6 + offset]);
-    float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 36), src[5 + offset]);
-    float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 36), vmulq_n_f32(src[4 + offset], 13));
-    t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 13)), src[6 + offset]);
-    t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 13)), src[6 + offset]);
-    tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 18), vmulq_n_f32(src[5 + offset], 2));
-    tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 9), vmulq_n_f32(src[4 + offset], 10));
-    t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 20)), src[6 + offset]);
-    t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 20)), src[6 + offset]);
-    tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 12), vmulq_n_f32(src[5 + offset], 3));
-    tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 4), vmulq_n_f32(src[4 + offset], 5));
-    t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 15)), src[6 + offset]);
-    t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 15)), src[6 + offset]);
-    t[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -36), vmulq_n_f32(src[3 + offset], 49)),
-                                    vmulq_n_f32(src[5 + offset], 14)),
-                          src[7 + offset]);
+    float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5));
+    float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25));
+    t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
+    t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]);
+    tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]);
+    tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5));
+    t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
+    t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]);
+    tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5));
+    tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25));
+    t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
+    t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]);
+    t[56 + l] =
+      vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)),
+                          vmulq_n_f32(src[5 + offset], 3.5)),
+                src[7 + offset]);
   }
   for (int l = 0; l < 8; ++l) {
     int offset = l * 8;
-    m[l] = vsubq_f32(
-      vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 36), vmulq_n_f32(t[2 + offset], 49)), vmulq_n_f32(t[4 + offset], 14)),
-      t[6 + offset]);
-    float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 36), t[5 + offset]);
-    float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 36), vmulq_n_f32(t[4 + offset], 13));
-    m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 13)), t[6 + offset]);
-    m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 13)), t[6 + offset]);
-    tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 18), vmulq_n_f32(t[5 + offset], 2));
-    tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 9), vmulq_n_f32(t[4 + offset], 10));
-    m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 20)), t[6 + offset]);
-    m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 20)), t[6 + offset]);
-    tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 12), vmulq_n_f32(t[5 + offset], 3));
-    tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 4), vmulq_n_f32(t[4 + offset], 5));
-    m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 15)), t[6 + offset]);
-    m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 15)), t[6 + offset]);
-    m[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -36), vmulq_n_f32(t[3 + offset], 49)),
-                                    vmulq_n_f32(t[5 + offset], 14)),
+    m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)),
+                               vmulq_n_f32(t[4 + offset], 3.5)),
+                     t[6 + offset]);
+    float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5));
+    float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25));
+    m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
+    m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]);
+    tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]);
+    tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5));
+    m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
+    m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]);
+    tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5));
+    tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25));
+    m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
+    m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]);
+    m[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)),
+                                    vmulq_n_f32(t[5 + offset], 3.5)),
                           t[7 + offset]);
   }
   for (int i = 0; i < 64; i++) {
@@ -321,37 +352,37 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
     }
     for (int l = 0; l < 8; ++l) {
       int offset = l * 8;
-      t[l] = 36 * src[offset] - 49 * src[2 + offset] + 14 * src[4 + offset] - src[6 + offset];
-      float tmp1 = 36 * src[1 + offset] + src[5 + offset];
-      float tmp2 = 36 * src[2 + offset] - 13 * src[4 + offset];
-      t[8 + l] = tmp1 + tmp2 - 13 * src[3 + offset] + src[6 + offset];
-      t[16 + l] = tmp2 - tmp1 + 13 * src[3 + offset] + src[6 + offset];
-      tmp1 = 18 * src[1 + offset] + 2 * src[5 + offset];
-      tmp2 = 9 * src[2 + offset] - 10 * src[4 + offset];
-      t[24 + l] = tmp1 + tmp2 - 20 * src[3 + offset] + src[6 + offset];
-      t[32 + l] = tmp2 - tmp1 + 20 * src[3 + offset] + src[6 + offset];
-      tmp1 = 12 * src[1 + offset] + 3 * src[5 + offset];
-      tmp2 = 4 * src[2 + offset] - 5 * src[4 + offset];
-      t[40 + l] = tmp1 + tmp2 - 15 * src[3 + offset] + src[6 + offset];
-      t[48 + l] = tmp2 - tmp1 + 15 * src[3 + offset] + src[6 + offset];
-      t[56 + l] = -36 * src[1 + offset] + 49 * src[3 + offset] - 14 * src[5 + offset] + src[7 + offset];
+      t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset];
+      float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset];
+      float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset];
+      t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset];
+      t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset];
+      tmp1 = 0.5625f * src[1 + offset] + src[5 + offset];
+      tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset];
+      t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset];
+      t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset];
+      tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset];
+      tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset];
+      t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset];
+      t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset];
+      t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset];
     }
     for (int l = 0; l < 8; ++l) {
       int offset = l * 8;
-      m[l] = 36 * t[offset] - 49 * t[2 + offset] + 14 * t[4 + offset] - t[6 + offset];
-      float tmp1 = 36 * t[1 + offset] + t[5 + offset];
-      float tmp2 = 36 * t[2 + offset] - 13 * t[4 + offset];
-      m[8 + l] = tmp1 + tmp2 - 13 * t[3 + offset] + t[6 + offset];
-      m[16 + l] = tmp2 - tmp1 + 13 * t[3 + offset] + t[6 + offset];
-      tmp1 = 18 * t[1 + offset] + 2 * t[5 + offset];
-      tmp2 = 9 * t[2 + offset] - 10 * t[4 + offset];
-      m[24 + l] = tmp1 + tmp2 - 20 * t[3 + offset] + t[6 + offset];
-      m[32 + l] = tmp2 - tmp1 + 20 * t[3 + offset] + t[6 + offset];
-      tmp1 = 12 * t[1 + offset] + 3 * t[5 + offset];
-      tmp2 = 4 * t[2 + offset] - 5 * t[4 + offset];
-      m[40 + l] = tmp1 + tmp2 - 15 * t[3 + offset] + t[6 + offset];
-      m[48 + l] = tmp2 - tmp1 + 15 * t[3 + offset] + t[6 + offset];
-      m[56 + l] = -36 * t[1 + offset] + 49 * t[3 + offset] - 14 * t[5 + offset] + t[7 + offset];
+      m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset];
+      float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset];
+      float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset];
+      m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset];
+      m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset];
+      tmp1 = 0.5625f * t[1 + offset] + t[5 + offset];
+      tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset];
+      m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset];
+      m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset];
+      tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset];
+      tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset];
+      m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset];
+      m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset];
+      m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset];
     }
     for (int k = 0; k < 64; ++k) {
      dst_data[i + k * dst_step] = m[k];
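The old and new 8x8 input-transform constants are related by a simple pattern: each coefficient on src[k] (and t[k]) is the old integer divided by a power of two, which reads as a rescaling of the transform's sample points; for example 36/2^6 = 0.5625, 49/2^4 = 3.0625, 14/2^2 = 3.5. The smaller magnitudes leave more fp32 headroom. A throwaway check of that pattern for the first row (my reading of the hunk, not part of the patch):

    #include <stdio.h>

    int main(void) {
      /* Old row-0 coefficients sit on src[0], src[2], src[4], src[6]. */
      double olds[4] = {36, 49, 14, 1};
      double news[4] = {0.5625, 3.0625, 3.5, 1};
      int idx[4] = {0, 2, 4, 6};
      for (int k = 0; k < 4; ++k) {
        /* Conjectured rule: new = old / 2^(6 - idx). */
        double expect = olds[k] / (double)(1 << (6 - idx[k]));
        printf("src[%d]: %g -> %g (expected %g)\n", idx[k], olds[k], news[k], expect);
      }
      return 0;
    }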
@@ -360,20 +391,38 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
 #endif
 }
 
-OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit) {
+OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) {
   if (input_unit == 4 && output_unit < 4) {
+    if (act_type == ActType_Relu) {
+      return OutputTransFuncReluList4[output_unit];
+    } else if (act_type == ActType_Relu6) {
+      return OutputTransFuncRelu6List4[output_unit];
+    } else {
      return OutputTransFuncList4[output_unit];
+    }
   } else if (input_unit == 6 && output_unit < 6) {
+    if (act_type == ActType_Relu) {
+      return OutputTransFuncReluList6[output_unit];
+    } else if (act_type == ActType_Relu6) {
+      return OutputTransFuncRelu6List6[output_unit];
+    } else {
      return OutputTransFuncList6[output_unit];
+    }
   } else if (input_unit == 8 && output_unit < 8) {
+    if (act_type == ActType_Relu) {
+      return OutputTransFuncReluList8[output_unit];
+    } else if (act_type == ActType_Relu6) {
+      return OutputTransFuncRelu6List8[output_unit];
+    } else {
      return OutputTransFuncList8[output_unit];
+    }
   } else {
     return NULL;
   }
 }
 
-void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                            int dst_step) {
+void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c) {
 #ifdef ENABLE_ARM
   float32x4_t src[16];
   float32x4_t t[8];
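GetOutputTransFunc now folds the activation into the function choice, so the per-tile loops stay branch-free. Presumably the kernel setup resolves the pointer once per convolution along these lines; the act_type_ field name and NNACL_ERR return are assumptions from the surrounding codebase, not shown in this patch:

    /* Hypothetical setup-time dispatch; a NULL result means the
     * unit-size/activation combination has no specialized kernel. */
    OutputTransFunc out_func =
      GetOutputTransFunc(conv_param->input_unit_, conv_param->output_unit_, conv_param->act_type_);
    if (out_func == NULL) {
      return NNACL_ERR;
    }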
@@ -390,12 +439,24 @@ void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float
     m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
     m[l + 2] = vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
   }
-  Store4Data;
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
 #else
   float src[16];
   float t[8];
   float m[4];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < r_c; ++i) {
     // load source data
     for (int j = 0; j < 16; ++j) {
      src[j] = src_data[i + j * src_step];
@@ -411,19 +472,157 @@ void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float
      m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset];
     }
     // store output
-    for (int k = 0; k < 2; ++k) {
-      int dst_k_offset = k * dst_step * C4NUM;
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
      int m_k_offset = k * 2;
-      for (int j = 0; j < 2; ++j) {
-        dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i];
+      for (int j = 0; j < r_w; ++j) {
+        dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i];
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[16];
+  float32x4_t t[8];
+  float32x4_t m[4];
+  float32x4_t zero = vdupq_n_f32(0);
+  Load16Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    t[l] = vaddq_f32(vaddq_f32(src[offset], src[1 + offset]), src[2 + offset]);
+    t[l + 4] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]), src[3 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 4;
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
+    m[l + 2] = vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l + 2] = vmaxq_f32(zero, m[l + 2]);
+  }
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[16];
+  float t[8];
+  float m[4];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 16; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 4;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset];
+      t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset];
+    }
+    for (int l = 0; l < 2; ++l) {
+      int offset = l * 4;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset];
+      m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 2;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[16];
+  float32x4_t t[8];
+  float32x4_t m[4];
+  float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t six = vdupq_n_f32(6);
+  Load16Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    t[l] = vaddq_f32(vaddq_f32(src[offset], src[1 + offset]), src[2 + offset]);
+    t[l + 4] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]), src[3 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 4;
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
+    m[l + 2] = vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l] = vminq_f32(six, m[l]);
+    m[l + 2] = vmaxq_f32(zero, m[l + 2]);
+    m[l + 2] = vminq_f32(six, m[l + 2]);
+  }
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[16];
+  float t[8];
+  float m[4];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 16; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 4;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset];
+      t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset];
+    }
+    for (int l = 0; l < 2; ++l) {
+      int offset = l * 4;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset];
+      m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 2;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
       }
     }
   }
 #endif
 }
 
-void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                            int dst_step) {
+void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c) {
 #ifdef ENABLE_ARM
   float32x4_t src[16];
   float32x4_t t[12];
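In every edge-tile store path the NEON branch reads single lanes out of a float32x4_t with m[k + m_k_offset][i]. That subscript syntax is a GCC/Clang vector extension rather than a NEON intrinsic; a portable equivalent would go through vgetq_lane_f32, sketched here for illustration only (not how this patch does it):

    #include <arm_neon.h>

    /* Extract lane i (0..3) from v without vector subscripting; the lane
     * argument of vgetq_lane_f32 must be a compile-time constant, hence
     * the switch. */
    static inline float lane_f32(float32x4_t v, int i) {
      switch (i) {
        case 0: return vgetq_lane_f32(v, 0);
        case 1: return vgetq_lane_f32(v, 1);
        case 2: return vgetq_lane_f32(v, 2);
        default: return vgetq_lane_f32(v, 3);
      }
    }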
@@ -444,12 +643,24 @@ void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float
     m[l + 3] = vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), bias_ptr);
     m[l + 6] = vaddq_f32(vaddq_f32(tmp, t[3 + offset]), bias_ptr);
   }
-  Store9Data;
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
 #else
   float src[16];
   float t[12];
   float m[9];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < r_c; ++i) {
     // load source data
     for (int j = 0; j < 16; ++j) {
      src[j] = src_data[i + j * src_step];
@@ -467,19 +678,172 @@ void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float
      m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset];
     }
     // store output
-    for (int k = 0; k < 3; ++k) {
-      int dst_k_offset = k * dst_step * C4NUM;
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
      int m_k_offset = k * 3;
-      for (int j = 0; j < 3; ++j) {
-        dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i];
+      for (int j = 0; j < r_w; ++j) {
+        dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i];
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[16];
+  float32x4_t t[12];
+  float32x4_t m[9];
+  float32x4_t zero = vdupq_n_f32(0);
+  Load16Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    float32x4_t tmp = vaddq_f32(src[1 + offset], src[2 + offset]);
+    t[l] = vaddq_f32(src[offset], tmp);
+    t[l + 4] = vsubq_f32(src[1 + offset], src[2 + offset]);
+    t[l + 8] = vaddq_f32(tmp, src[3 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 4;
+    float32x4_t tmp = vaddq_f32(t[1 + offset], t[2 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(t[offset], tmp), bias_ptr);
+    m[l + 3] = vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), bias_ptr);
+    m[l + 6] = vaddq_f32(vaddq_f32(tmp, t[3 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l + 3] = vmaxq_f32(zero, m[l + 3]);
+    m[l + 6] = vmaxq_f32(zero, m[l + 6]);
+  }
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[16];
+  float t[12];
+  float m[9];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 16; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 4;
+      t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset];
+      t[l + 4] = src[1 + offset] - src[2 + offset];
+      t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset];
+    }
+    for (int l = 0; l < 3; ++l) {
+      int offset = l * 4;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset];
+      m[l + 3] = t[1 + offset] - t[2 + offset];
+      m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 3;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[16];
+  float32x4_t t[12];
+  float32x4_t m[9];
+  float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t six = vdupq_n_f32(6);
+  Load16Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    float32x4_t tmp = vaddq_f32(src[1 + offset], src[2 + offset]);
+    t[l] = vaddq_f32(src[offset], tmp);
+    t[l + 4] = vsubq_f32(src[1 + offset], src[2 + offset]);
+    t[l + 8] = vaddq_f32(tmp, src[3 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 4;
+    float32x4_t tmp = vaddq_f32(t[1 + offset], t[2 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(t[offset], tmp), bias_ptr);
+    m[l + 3] = vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), bias_ptr);
+    m[l + 6] = vaddq_f32(vaddq_f32(tmp, t[3 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l] = vminq_f32(six, m[l]);
+    m[l + 3] = vmaxq_f32(zero, m[l + 3]);
+    m[l + 3] = vminq_f32(six, m[l + 3]);
+    m[l + 6] = vmaxq_f32(zero, m[l + 6]);
+    m[l + 6] = vminq_f32(six, m[l + 6]);
+  }
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[16];
+  float t[12];
+  float m[9];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 16; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 4;
+      t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset];
+      t[l + 4] = src[1 + offset] - src[2 + offset];
+      t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset];
+    }
+    for (int l = 0; l < 3; ++l) {
+      int offset = l * 4;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset];
+      m[l + 3] = t[1 + offset] - t[2 + offset];
+      m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 3;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
       }
     }
   }
 #endif
 }
 
-void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                            int dst_step) {
+void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c) {
 #ifdef ENABLE_ARM
   float32x4_t src[36];
   float32x4_t t[12];
@@ -504,12 +868,24 @@ void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float
                          t[5 + offset]),
                bias_ptr);
   }
-  Store4Data;
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
 #else
   float src[36];
   float t[12];
   float m[4];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < r_c; ++i) {
     // load source data
     for (int j = 0; j < 36; ++j) {
      src[j] = src_data[i + j * src_step];
@@ -525,49 +901,216 @@ void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float
      m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
     }
     // store output
-    for (int k = 0; k < 2; ++k) {
-      int dst_k_offset = k * dst_step * C4NUM;
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
      int m_k_offset = k * 2;
-      for (int j = 0; j < 2; ++j) {
-        dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i];
+      for (int j = 0; j < r_w; ++j) {
+        dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i];
      }
     }
   }
 #endif
 }
-void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                            int dst_step) {
+
+void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c) {
 #ifdef ENABLE_ARM
   float32x4_t src[36];
-  float32x4_t t[18];
-  float32x4_t m[9];
+  float32x4_t t[12];
+  float32x4_t m[4];
+  float32x4_t zero = vdupq_n_f32(0);
   Load36Data;
   float32x4_t bias_ptr = vld1q_f32(bias_data);
   for (int l = 0; l < 6; ++l) {
     int offset = l * 6;
-    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
-    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
-    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
-    t[l + 6] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
-                         vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2));
-    t[l + 12] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), src[5 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]),
+                     src[4 + offset]);
+    t[l + 6] = vaddq_f32(vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
+                                   vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2)),
+                         src[5 + offset]);
   }
-  for (int l = 0; l < 3; ++l) {
+  for (int l = 0; l < 2; ++l) {
     int offset = l * 6;
-    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
-    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
-    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
-    m[l + 3] = vaddq_f32(
-      vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+    m[l] = vaddq_f32(
+      vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]),
+                t[4 + offset]),
      bias_ptr);
-    m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l + 2] = vaddq_f32(vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]),
+                                             vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+                                   t[5 + offset]),
+                         bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l + 2] = vmaxq_f32(zero, m[l + 2]);
+  }
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
   }
-  Store9Data;
 #else
   float src[36];
-  float t[18];
-  float m[9];
-  for (int i = 0; i < C4NUM; ++i) {
+  float t[12];
+  float m[4];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 2; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 2;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[12];
+  float32x4_t m[4];
+  float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t six = vdupq_n_f32(6);
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    t[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]),
+                     src[4 + offset]);
+    t[l + 6] = vaddq_f32(vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
+                                   vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2)),
+                         src[5 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 6;
+    m[l] = vaddq_f32(
+      vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]),
+                t[4 + offset]),
+      bias_ptr);
+    m[l + 2] = vaddq_f32(vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]),
+                                             vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+                                   t[5 + offset]),
+                         bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l] = vminq_f32(six, m[l]);
+    m[l + 2] = vmaxq_f32(zero, m[l + 2]);
+    m[l + 2] = vminq_f32(six, m[l + 2]);
+  }
+  if (r_c == C4NUM && r_h == 2 && r_w == 2) {
+    Store4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[12];
+  float m[4];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 2; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 2;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
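All the *ReluUnit/*Relu6Unit variants share one pattern: the usual output transform, then a clamp applied immediately after the bias add, via vmaxq_f32/vminq_f32 on NEON and a two-sided compare in the scalar branch. Reduced to scalar helpers, the fused epilogues are simply:

    /* Scalar equivalents of the fused activation epilogues in this patch. */
    static inline float relu_epilogue(float v) { return v > 0 ? v : 0; }

    static inline float relu6_epilogue(float v) {
      v = v > 0 ? v : 0;    /* vmaxq_f32(zero, m[..]) */
      return v < 6 ? v : 6; /* vminq_f32(six, m[..]) */
    }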
+
+void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[18];
+  float32x4_t m[9];
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
+                         vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = vaddq_f32(
+      vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+      bias_ptr);
+    m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[18];
+  float m[9];
+  for (int i = 0; i < r_c; ++i) {
     // load source data
     for (int j = 0; j < 36; ++j) {
      src[j] = src_data[i + j * src_step];
@@ -585,18 +1128,182 @@ void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float
      m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset];
     }
     // store output
-    for (int k = 0; k < 3; ++k) {
-      int dst_k_offset = k * dst_step * C4NUM;
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
      int m_k_offset = k * 3;
-      for (int j = 0; j < 3; ++j) {
-        dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i];
+      for (int j = 0; j < r_w; ++j) {
+        dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i];
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[18];
+  float32x4_t m[9];
+  float32x4_t zero = vdupq_n_f32(0);
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
+                         vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = vaddq_f32(
+      vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+      bias_ptr);
+    m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l + 3] = vmaxq_f32(zero, m[l + 3]);
+    m[l + 6] = vmaxq_f32(zero, m[l + 6]);
+  }
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[18];
+  float m[9];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]);
+      t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 3; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]);
+      m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 3;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[18];
+  float32x4_t m[9];
+  float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t six = vdupq_n_f32(6);
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]),
+                         vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = vaddq_f32(
+      vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)),
+      bias_ptr);
+    m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l] = vminq_f32(six, m[l]);
+    m[l + 3] = vmaxq_f32(zero, m[l + 3]);
+    m[l + 3] = vminq_f32(six, m[l + 3]);
+    m[l + 6] = vmaxq_f32(zero, m[l + 6]);
+    m[l + 6] = vminq_f32(six, m[l + 6]);
+  }
+  if (r_c == C4NUM && r_h == 3 && r_w == 3) {
+    Store9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[18];
+  float m[9];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]);
+      t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 3; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]);
+      m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 3;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
      }
     }
   }
 #endif
 }
-
-void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                            int dst_step) {
+
+void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c) {
 #ifdef ENABLE_ARM
   float32x4_t src[36];
   float32x4_t t[24];
@@ -625,152 +1332,1700 @@ void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float
     m[l + 8] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr);
     m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), t[5 + offset]), bias_ptr);
   }
-  Store16Data;
+  if (r_c == C4NUM && r_h == 4 && r_w == 4) {
+    Store16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
 #else
   float src[36];
   float t[24];
   float m[16];
-  for (int i = 0; i < C4NUM; ++i) {
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]);
+      t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]);
+      t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]);
+      m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]);
+      m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 4;
+      for (int j = 0; j < r_w; ++j) {
+        dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i];
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[24];
+  float32x4_t m[16];
+  float32x4_t zero = vdupq_n_f32(0);
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
+    float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2));
+    t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4));
+    t[l + 18] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
+    float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr);
+    m[l + 8] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr);
+    m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l + 4] = vmaxq_f32(zero, m[l + 4]);
+    m[l + 8] = vmaxq_f32(zero, m[l + 8]);
+    m[l + 12] = vmaxq_f32(zero, m[l + 12]);
+  }
+  if (r_c == C4NUM && r_h == 4 && r_w == 4) {
+    Store16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[24];
+  float m[16];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]);
+      t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]);
+      t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]);
+      m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]);
+      m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 4;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+#endif
+}
+
+void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c) {
+#ifdef ENABLE_ARM
+  float32x4_t src[36];
+  float32x4_t t[24];
+  float32x4_t m[16];
+  float32x4_t zero = vdupq_n_f32(0);
+  float32x4_t six = vdupq_n_f32(6);
+  Load36Data;
+  float32x4_t bias_ptr = vld1q_f32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]);
+    float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]);
+    float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]);
+    t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2);
+    t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2));
+    t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4));
+    t[l + 18] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]);
+    float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]);
+    float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]);
+    m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr);
+    m[l + 8] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr);
+    m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = vmaxq_f32(zero, m[l]);
+    m[l] = vminq_f32(six, m[l]);
+    m[l + 4] = vmaxq_f32(zero, m[l + 4]);
+    m[l + 4] = vminq_f32(six, m[l + 4]);
+    m[l + 8] = vmaxq_f32(zero, m[l + 8]);
+    m[l + 8] = vminq_f32(six, m[l + 8]);
+    m[l + 12] = vmaxq_f32(zero, m[l + 12]);
+    m[l + 12] = vminq_f32(six, m[l + 12]);
+  }
+  if (r_c == C4NUM && r_h == 4 && r_w == 4) {
+    Store16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i];
+        }
+      }
+    }
+  }
+#else
+  float src[36];
+  float t[24];
+  float m[16];
+  for (int i = 0; i < r_c; ++i) {
+    // load source data
+    for (int j = 0; j < 36; ++j) {
+      src[j] = src_data[i + j * src_step];
+    }
+    for (int l = 0; l < 6; ++l) {
+      int offset = l * 6;
+      t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset];
+      t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]);
+      t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]);
+      t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset];
+    }
+    for (int l = 0; l < 4; ++l) {
+      int offset = l * 6;
+      m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset];
+      m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]);
+      m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]);
+      m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 4;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[36]; + float32x4_t t[30]; + float32x4_t m[25]; + Load36Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)); + t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)); + t[l + 18] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)); + t[l + 24] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; +
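/*
 * Every one of these OutputTransformNxM*Unit variants ends with the same clipped store: after the inverse
 * transform and bias-add, channel i of unit cell (row k, col j) goes straight into the NHWC output at
 * dst_data[i + (k * dst_step + j) * out_c], where dst_step is output_w. r_h, r_w and r_c clip the write for
 * bottom/right edge tiles and for a partial final channel block. A minimal standalone sketch of that
 * addressing (hypothetical helper, assuming one channel's unit x unit tile in row-major order in m;
 * illustration only, not part of this patch):
 */
static void StoreUnitTileNhwc(const float *m, float *dst_data, int out_c, int dst_step, int unit, int r_w, int r_h) {
  for (int k = 0; k < r_h; ++k) {    // rows of the tile that are still inside the output
    for (int j = 0; j < r_w; ++j) {  // cols of the tile that are still inside the output
      dst_data[(k * dst_step + j) * out_c] = m[k * unit + j];  // channel offset i is added by the caller
    }
  }
}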
} + } + } +#endif +} + +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[36]; + float32x4_t t[30]; + float32x4_t m[25]; + float32x4_t zero = vdupq_n_f32(0); + Load36Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)); + t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)); + t[l + 18] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)); + t[l + 24] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 5] = vmaxq_f32(zero, m[l + 5]); + m[l + 10] = vmaxq_f32(zero, m[l + 10]); + m[l + 15] = vmaxq_f32(zero, m[l + 15]); + m[l + 20] = vmaxq_f32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for
(int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[36]; + float32x4_t t[30]; + float32x4_t m[25]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); + Load36Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)); + t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)); + t[l + 18] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)); + t[l + 24] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 5] = vmaxq_f32(zero, m[l + 5]); + m[l + 5] = vminq_f32(six, m[l + 5]); + m[l + 10] = vmaxq_f32(zero, m[l + 10]); + m[l + 10] = vminq_f32(six, m[l + 10]); + m[l + 15] = vmaxq_f32(zero, m[l + 15]); + m[l + 15] = vminq_f32(six, m[l + 15]); + m[l + 20] = vmaxq_f32(zero, m[l + 20]); + m[l + 20] = vminq_f32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] -
t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[16]; + float32x4_t m[4]; + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), t[7 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int 
m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[16]; + float32x4_t m[4]; + float32x4_t zero = vdupq_n_f32(0); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 2] = vmaxq_f32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[16]; + float32x4_t m[4]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 2] = vmaxq_f32(zero, m[l + 2]); + m[l + 2] = vminq_f32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[24]; + float32x4_t m[9]; + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), t[7 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { +
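/*
 * The constants in the 8-tap units (0.5/1.5, 0.25/2.25, 0.125/3.375, 0.0625/5.0625) are powers of the
 * Winograd interpolation points p in {0, +-0.5, +-1, +-1.5}: row k of the inverse transform A^T applies
 * p^k to src[0..6], and the extra tap src[7] only feeds the highest row (k = out_unit - 1), e.g. t[l + 16]
 * for the 8x3 unit. A sketch that regenerates row k under that assumption (illustration only; the kernels
 * above hard-code the results, and 0.125 = 0.5^3, 3.375 = 1.5^3, 0.0625 = 0.5^4, 5.0625 = 1.5^4):
 */
#include <math.h>
static void InverseTransRow8(int k, int out_unit, float row[8]) {
  const float points[7] = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f, 1.5f, -1.5f};
  for (int n = 0; n < 7; ++n) {
    row[n] = (float)pow(points[n], k);  // pow(0, 0) == 1, so row 0 simply sums src[0..6]
  }
  row[7] = (k == out_unit - 1) ? 1.0f : 0.0f;  // src[7] contributes only to the last output row
}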
dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[24]; + float32x4_t m[9]; + float32x4_t zero = vdupq_n_f32(0); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 3] = vmaxq_f32(zero, m[l + 3]); + m[l + 6] = vmaxq_f32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int 
k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[24]; + float32x4_t m[9]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 3] = vmaxq_f32(zero, m[l + 3]); + m[l + 3] = vminq_f32(six, m[l + 3]); + m[l + 6] = vmaxq_f32(zero, m[l + 6]); + m[l + 6] = vminq_f32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + 
offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[32]; + float32x4_t m[16]; + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - 
src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[32]; + float32x4_t m[16]; + float32x4_t zero = vdupq_n_f32(0); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 4] = vmaxq_f32(zero, m[l + 4]); + m[l + 8] = vmaxq_f32(zero, m[l + 8]); + m[l + 12] = vmaxq_f32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int 
dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[32]; + float32x4_t m[16]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 4] = vmaxq_f32(zero, m[l + 4]); + m[l + 4] = vminq_f32(six, m[l + 4]); + m[l + 8] = vmaxq_f32(zero, m[l + 8]); + m[l + 8] = vminq_f32(six, m[l + 8]); + m[l + 12] = vmaxq_f32(zero, m[l + 12]); + m[l + 12] = vminq_f32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + 
(src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[40]; + float32x4_t m[25]; + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; 
k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#ifdef ENABLE_ARM + float32x4_t src[64]; + float32x4_t t[40]; + float32x4_t m[25]; + float32x4_t zero = vdupq_n_f32(0); + Load64Data; + float32x4_t bias_ptr = vld1q_f32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + 
offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 5] = vmaxq_f32(zero, m[l + 5]); + m[l + 10] = vmaxq_f32(zero, m[l + 10]); + m[l + 15] = vmaxq_f32(zero, m[l + 15]); + m[l + 20] = vmaxq_f32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +#else + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { // load source data - for (int j = 0; j < 36; ++j) { + for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; } - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; - t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); - t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); - t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 4; ++l) { - int offset = l * 6; - m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; - m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); - m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); - m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + 
offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 4; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 4; - for (int j = 0; j < 4; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } #endif } -void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM - float32x4_t src[36]; - float32x4_t t[30]; + float32x4_t src[64]; + float32x4_t t[40]; float32x4_t m[25]; - Load36Data; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); + Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); - for (int l = 0; l < 6; ++l) { - int offset = l * 6; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); - float32x4_t tmp3 = vsubq_f32(src[1 + offset], src[2 + offset]); - float32x4_t tmp4 = vsubq_f32(src[3 + offset], src[4 + offset]); - t[l] = vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2); - t[l + 6] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)); - t[l + 12] = vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)); - t[l + 18] = vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)); - t[l + 24] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), src[5 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), src[7 + offset]); } for (int l = 0; l < 5; ++l) { - int offset = l * 6; + int offset = l * 8; float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); - float32x4_t tmp3 = vsubq_f32(t[1 + offset], t[2 + offset]); - float32x4_t tmp4 = vsubq_f32(t[3 + offset], t[4 + offset]); - m[l] = vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 5] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 2)), bias_ptr); - m[l + 10] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), bias_ptr); - m[l + 15] = vaddq_f32(vaddq_f32(tmp3, vmulq_n_f32(tmp4, 8)), bias_ptr); - m[l + 20] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), t[5 + offset]), bias_ptr); + float32x4_t tmp3 = 
vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 5] = vmaxq_f32(zero, m[l + 5]); + m[l + 5] = vminq_f32(six, m[l + 5]); + m[l + 10] = vmaxq_f32(zero, m[l + 10]); + m[l + 10] = vminq_f32(six, m[l + 10]); + m[l + 15] = vmaxq_f32(zero, m[l + 15]); + m[l + 15] = vminq_f32(six, m[l + 15]); + m[l + 20] = vmaxq_f32(zero, m[l + 20]); + m[l + 20] = vminq_f32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } - Store25Data; #else - float src[36]; - float t[30]; + float src[64]; + float t[40]; float m[25]; - for (int i = 0; i < C4NUM; ++i) { + for (int i = 0; i < r_c; ++i) { // load source data - for (int j = 0; j < 36; ++j) { + for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; } - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; - t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); - t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); - t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); - t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; } for (int l = 0; l < 5; ++l) { - int offset = l * 6; - m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; - m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); - m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); - m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); - m[l + 20] = t[1 + offset] + t[2 + 
offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 5; ++k) { - int dst_k_offset = k * dst_step * C4NUM; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; int m_k_offset = k * 5; - for (int j = 0; j < 5; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } #endif } -void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; - float32x4_t t[16]; - float32x4_t m[4]; + float32x4_t t[48]; + float32x4_t m[36]; Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { int offset = l * 8; - t[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src[offset], src[1 + offset]), src[2 + offset]), - src[3 + offset]), - src[4 + offset]), - src[5 + offset]), - src[6 + offset]); - t[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(vsubq_f32(src[1 + offset], src[2 + offset]), - vmulq_n_f32(vsubq_f32(src[3 + offset], src[4 + offset]), 2)), - vmulq_n_f32(vsubq_f32(src[5 + offset], src[6 + offset]), 3)), - src[7 + offset]); + float32x4_t tmp1 = vaddq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp2 = vaddq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp3 = vaddq_f32(src[5 + offset], src[6 + offset]); + float32x4_t tmp4 = vsubq_f32(src[1 + offset], src[2 + offset]); + float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); + float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), src[7 + offset]); } - for (int l = 0; l < 2; ++l) { + for (int l = 0; l < 6; ++l) { int offset = l * 8; - m[l] = vaddq_f32( - vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), - t[4 + offset]), - t[5 + offset]), - t[6 + offset]), + float32x4_t tmp1 = vaddq_f32(t[1 + offset], 
t[2 + offset]); + float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp3 = vaddq_f32(t[5 + offset], t[6 + offset]); + float32x4_t tmp4 = vsubq_f32(t[1 + offset], t[2 + offset]); + float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); + float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), t[7 + offset]), bias_ptr); - m[l + 2] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vsubq_f32(t[1 + offset], t[2 + offset]), - vmulq_n_f32(vsubq_f32(t[3 + offset], t[4 + offset]), 2)), - vmulq_n_f32(vsubq_f32(t[5 + offset], t[6 + offset]), 3)), - t[7 + offset]), - bias_ptr); } - Store4Data; + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } #else float src[64]; - float t[16]; - float m[4]; - for (int i = 0; i < C4NUM; ++i) { + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -779,32 +3034,50 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 2; ++l) { + for (int l = 0; l < 
6; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; - m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + - 3 * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 2; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 2; - for (int j = 0; j < 2; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; } } } #endif } -void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; - float32x4_t t[24]; - float32x4_t m[9]; + float32x4_t t[48]; + float32x4_t m[36]; + float32x4_t zero = vdupq_n_f32(0); Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { @@ -816,10 +3089,14 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)); - t[l + 16] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), src[7 + offset]); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), src[7 + offset]); } - for (int l = 0; l < 3; ++l) { + for (int l = 0; l < 6; ++l) { int offset = l * 8; float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); @@ -828,16 +3105,47 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 3] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 
2)), vmulq_n_f32(tmp6, 3)), bias_ptr); - m[l + 6] = vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), t[7 + offset]), bias_ptr); + m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 6] = vmaxq_f32(zero, m[l + 6]); + m[l + 12] = vmaxq_f32(zero, m[l + 12]); + m[l + 18] = vmaxq_f32(zero, m[l + 18]); + m[l + 24] = vmaxq_f32(zero, m[l + 24]); + m[l + 30] = vmaxq_f32(zero, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } - Store9Data; #else float src[64]; - float t[24]; - float m[9]; - for (int i = 0; i < C4NUM; ++i) { + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -846,36 +3154,53 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]); - t[l + 16] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + - 9 * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 3; ++l) { + for (int l = 0; l < 6; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] 
+ t[5 + offset] + t[6 + offset]; - m[l + 3] = - t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + 3 * (t[5 + offset] - t[6 + offset]); - m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + - 9 * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 3; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 3; - for (int j = 0; j < 3; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } #endif } -void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; - float32x4_t t[32]; - float32x4_t m[16]; + float32x4_t t[48]; + float32x4_t m[36]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { @@ -887,11 +3212,14 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)); - t[l + 16] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)); - t[l + 24] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)), src[7 + offset]); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), src[7 + offset]); } - for (int l = 0; l < 4; ++l) { + for (int l = 0; l < 6; ++l) { int offset = l * 8; float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); @@ -900,17 +3228,53 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); 
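      /* Editor's note, not part of the original patch: a hedged reading of the new constants used in these
       * rewritten output transforms. The pattern 0.5/1.5, 0.25/2.25, 0.125/3.375, 0.0625/5.0625,
       * 0.03125/7.59375, 0.015625/11.390625 is consistent with powers 0.5^k and 1.5^k, i.e. with Winograd
       * interpolation points 0, +-0.5, +-1, +-1.5 (and infinity) for the 8x8 input tile, replacing the old
       * points +-1, +-2, +-3 whose powers grow to 243 and 729 and amplify fp32 rounding error. Assuming the
       * pairs (t[1], t[2]), (t[3], t[4]), (t[5], t[6]) correspond to points +-0.5, +-1, +-1.5, row k of the
       * output-transform matrix A^T applied below reduces to
       *   even k: 0.5^k * (t1 + t2) + (t3 + t4) + 1.5^k * (t5 + t6)   (the tmp1..tmp3 sums)
       *   odd  k: 0.5^k * (t1 - t2) + (t3 - t4) + 1.5^k * (t5 - t6)   (the tmp4..tmp6 differences)
       * with t[offset] added only to row 0 and t[7 + offset] only to the last row. The m[l], m[l + 6], ...,
       * m[l + 30] statements that follow are rows k = 0..5 of that matrix, each with the bias pre-added and,
       * in this Relu6 variant, clamped against the zero/six vectors afterwards; the first pass over src[]
       * above applies the same row pattern column-wise. */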
float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 4] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)), bias_ptr); - m[l + 8] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), bias_ptr); - m[l + 12] = vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)), t[7 + offset]), bias_ptr); + m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 6] = vmaxq_f32(zero, m[l + 6]); + m[l + 6] = vminq_f32(six, m[l + 6]); + m[l + 12] = vmaxq_f32(zero, m[l + 12]); + m[l + 12] = vminq_f32(six, m[l + 12]); + m[l + 18] = vmaxq_f32(zero, m[l + 18]); + m[l + 18] = vminq_f32(six, m[l + 18]); + m[l + 24] = vmaxq_f32(zero, m[l + 24]); + m[l + 24] = vminq_f32(six, m[l + 24]); + m[l + 30] = vmaxq_f32(zero, m[l + 30]); + m[l + 30] = vminq_f32(six, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } - Store16Data; #else float src[64]; - float t[32]; - float m[16]; - for (int i = 0; i < C4NUM; ++i) { + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -919,40 +3283,52 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]); - t[l + 16] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + - 9 * (src[5 + offset] + src[6 + offset]); - t[l + 24] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + - 27 * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + 
src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 4; ++l) { + for (int l = 0; l < 6; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; - m[l + 4] = - t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + 3 * (t[5 + offset] - t[6 + offset]); - m[l + 8] = - t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + 9 * (t[5 + offset] + t[6 + offset]); - m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + - 27 * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 4; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 4; - for (int j = 0; j < 4; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } #endif } -void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; - float32x4_t t[40]; - float32x4_t m[25]; + float32x4_t t[56]; + float32x4_t m[49]; Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { @@ -964,12 +3340,15 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)); - t[l + 16] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)); - t[l + 24] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)); - t[l + 32] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)), src[7 + offset]); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)); + t[l + 48] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), src[7 + offset]); } - for (int l = 0; l < 5; ++l) { + for (int l = 0; l < 7; ++l) { int offset = l * 8; float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); @@ -978,18 +3357,43 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 5] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)), bias_ptr); - m[l + 10] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), bias_ptr); - m[l + 15] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)), bias_ptr); - m[l + 20] = vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)), t[7 + offset]), bias_ptr); + m[l + 7] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), t[7 + offset]), + 
bias_ptr); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } - Store25Data; #else float src[64]; - float t[40]; - float m[25]; - for (int i = 0; i < C4NUM; ++i) { + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -998,44 +3402,54 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]); - t[l + 16] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + - 9 * (src[5 + offset] + src[6 + offset]); - t[l + 24] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + - 27 * (src[5 + offset] - src[6 + offset]); - t[l + 32] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + - 81 * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 5; ++l) { + for (int l = 0; l < 7; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; - m[l + 5] = - t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + 3 * (t[5 + offset] - t[6 + offset]); - m[l + 10] = - t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + 9 * (t[5 + offset] + t[6 + offset]); - m[l + 15] = - t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + 27 * (t[5 + offset] - t[6 + offset]); - m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + - 81 * (t[5 + offset] + 
t[6 + offset]) + t[7 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 5; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 5; - for (int j = 0; j < 5; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; } } } #endif } -void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; - float32x4_t t[48]; - float32x4_t m[36]; + float32x4_t t[56]; + float32x4_t m[49]; + float32x4_t zero = vdupq_n_f32(0); Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { @@ -1047,13 +3461,15 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)); - t[l + 16] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)); - t[l + 24] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)); - t[l + 32] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)); - t[l + 40] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 32)), vmulq_n_f32(tmp6, 243)), src[7 + offset]); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)); + t[l + 48] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), src[7 + offset]); } - for (int l = 0; l < 6; ++l) { + for (int l = 0; l < 7; ++l) { int offset = l * 8; float32x4_t tmp1 = vaddq_f32(t[1 + offset], t[2 + offset]); float32x4_t tmp2 = vaddq_f32(t[3 + offset], t[4 + offset]); @@ -1062,28 +3478,50 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 
= vsubq_f32(t[3 + offset], t[4 + offset]); float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)), bias_ptr); - m[l + 12] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), bias_ptr); - m[l + 18] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)), bias_ptr); - m[l + 24] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)), bias_ptr); - m[l + 30] = vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 32)), vmulq_n_f32(tmp6, 243)), t[7 + offset]), bias_ptr); + m[l + 7] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l + 7] = vmaxq_f32(zero, m[l + 7]); + m[l + 14] = vmaxq_f32(zero, m[l + 14]); + m[l + 21] = vmaxq_f32(zero, m[l + 21]); + m[l + 28] = vmaxq_f32(zero, m[l + 28]); + m[l + 35] = vmaxq_f32(zero, m[l + 35]); + m[l + 42] = vmaxq_f32(zero, m[l + 42]); } - for (int i = 0; i < 6; i++) { - int dst_k_offset = i * dst_step * C4NUM; - int m_k_offset = i * 6; - vst1q_f32(dst_data + dst_k_offset + 0 * C4NUM, m[m_k_offset]); - vst1q_f32(dst_data + dst_k_offset + 1 * C4NUM, m[m_k_offset + 1]); - vst1q_f32(dst_data + dst_k_offset + 2 * C4NUM, m[m_k_offset + 2]); - vst1q_f32(dst_data + dst_k_offset + 3 * C4NUM, m[m_k_offset + 3]); - vst1q_f32(dst_data + dst_k_offset + 4 * C4NUM, m[m_k_offset + 4]); - vst1q_f32(dst_data + dst_k_offset + 5 * C4NUM, m[m_k_offset + 5]); + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } #else float src[64]; - float t[48]; - float m[36]; - for (int i = 0; i < C4NUM; ++i) { + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -1092,48 +3530,57 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = 
src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]); - t[l + 16] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + - 9 * (src[5 + offset] + src[6 + offset]); - t[l + 24] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + - 27 * (src[5 + offset] - src[6 + offset]); - t[l + 32] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + - 81 * (src[5 + offset] + src[6 + offset]); - t[l + 40] = src[1 + offset] - src[2 + offset] + 32 * (src[3 + offset] - src[4 + offset]) + - 243 * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; } - for (int l = 0; l < 6; ++l) { + for (int l = 0; l < 7; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; - m[l + 6] = - t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + 3 * (t[5 + offset] - t[6 + offset]); - m[l + 12] = - t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + 9 * (t[5 + offset] + t[6 + offset]); - m[l + 18] = - t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + 27 * (t[5 + offset] - t[6 + offset]); - m[l + 24] = - t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + 81 * (t[5 + offset] + t[6 + offset]); - m[l + 30] = t[1 + offset] - t[2 + offset] + 32 * (t[3 + offset] - t[4 + offset]) + - 243 * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 6; ++k) { - int dst_k_offset = k * dst_step * C4NUM; - int m_k_offset = k * 6; - for (int j = 0; j < 6; ++j) { - 
dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } #endif } -void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { + +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { #ifdef ENABLE_ARM float32x4_t src[64]; float32x4_t t[56]; float32x4_t m[49]; + float32x4_t zero = vdupq_n_f32(0); + float32x4_t six = vdupq_n_f32(6); Load64Data; float32x4_t bias_ptr = vld1q_f32(bias_data); for (int l = 0; l < 8; ++l) { @@ -1145,12 +3592,13 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(src[3 + offset], src[4 + offset]); float32x4_t tmp6 = vsubq_f32(src[5 + offset], src[6 + offset]); t[l] = vaddq_f32(vaddq_f32(vaddq_f32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)); - t[l + 16] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)); - t[l + 24] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)); - t[l + 32] = vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)); - t[l + 40] = vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 32)), vmulq_n_f32(tmp6, 243)); - t[l + 48] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 64)), vmulq_n_f32(tmp3, 729)), src[7 + offset]); + t[l + 8] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)); + t[l + 16] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)); + t[l + 24] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)); + t[l + 32] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)); + t[l + 40] = vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)); + t[l + 48] = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), src[7 + offset]); } for (int l = 0; l < 7; ++l) { int offset = l * 8; @@ -1161,30 +3609,57 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float float32x4_t tmp5 = vsubq_f32(t[3 + offset], t[4 + offset]); float32x4_t tmp6 = vsubq_f32(t[5 + offset], t[6 + offset]); m[l] = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 7] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 2)), vmulq_n_f32(tmp6, 3)), bias_ptr); - m[l + 14] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 4)), vmulq_n_f32(tmp3, 9)), bias_ptr); - m[l + 21] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 8)), vmulq_n_f32(tmp6, 27)), bias_ptr); - m[l + 28] = vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 16)), vmulq_n_f32(tmp3, 81)), bias_ptr); - m[l + 35] = vaddq_f32(vaddq_f32(vaddq_f32(tmp4, vmulq_n_f32(tmp5, 32)), vmulq_n_f32(tmp6, 243)), bias_ptr); + m[l + 7] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.5), tmp5), vmulq_n_f32(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.25), tmp2), vmulq_n_f32(tmp3, 2.25)), bias_ptr); + m[l + 21] = 
vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.125), tmp5), vmulq_n_f32(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.0625), tmp2), vmulq_n_f32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp4, 0.03125), tmp5), vmulq_n_f32(tmp6, 7.59375)), bias_ptr); m[l + 42] = vaddq_f32( - vaddq_f32(vaddq_f32(vaddq_f32(tmp1, vmulq_n_f32(tmp2, 64)), vmulq_n_f32(tmp3, 729)), t[7 + offset]), bias_ptr); + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(tmp1, 0.015625), tmp2), vmulq_n_f32(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f32(zero, m[l]); + m[l] = vminq_f32(six, m[l]); + m[l + 7] = vmaxq_f32(zero, m[l + 7]); + m[l + 7] = vminq_f32(six, m[l + 7]); + m[l + 14] = vmaxq_f32(zero, m[l + 14]); + m[l + 14] = vminq_f32(six, m[l + 14]); + m[l + 21] = vmaxq_f32(zero, m[l + 21]); + m[l + 21] = vminq_f32(six, m[l + 21]); + m[l + 28] = vmaxq_f32(zero, m[l + 28]); + m[l + 28] = vminq_f32(six, m[l + 28]); + m[l + 35] = vmaxq_f32(zero, m[l + 35]); + m[l + 35] = vminq_f32(six, m[l + 35]); + m[l + 42] = vmaxq_f32(zero, m[l + 42]); + m[l + 42] = vminq_f32(six, m[l + 42]); } - for (int i = 0; i < 7; i++) { - int dst_k_offset = i * dst_step * C4NUM; - int m_k_offset = i * 7; - vst1q_f32(dst_data + dst_k_offset + 0 * C4NUM, m[m_k_offset]); - vst1q_f32(dst_data + dst_k_offset + 1 * C4NUM, m[m_k_offset + 1]); - vst1q_f32(dst_data + dst_k_offset + 2 * C4NUM, m[m_k_offset + 2]); - vst1q_f32(dst_data + dst_k_offset + 3 * C4NUM, m[m_k_offset + 3]); - vst1q_f32(dst_data + dst_k_offset + 4 * C4NUM, m[m_k_offset + 4]); - vst1q_f32(dst_data + dst_k_offset + 5 * C4NUM, m[m_k_offset + 5]); - vst1q_f32(dst_data + dst_k_offset + 6 * C4NUM, m[m_k_offset + 6]); + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } } #else float src[64]; float t[56]; float m[49]; - for (int i = 0; i < C4NUM; ++i) { + for (int i = 0; i < r_c; ++i) { // load source data for (int j = 0; j < 64; ++j) { src[j] = src_data[i + j * src_step]; @@ -1193,41 +3668,44 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float int offset = l * 8; t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + src[6 + offset]; - t[l + 8] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + - 3 * (src[5 + offset] - src[6 + offset]); - t[l + 16] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + - 9 * (src[5 + offset] + src[6 + offset]); - t[l + 24] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + - 27 * (src[5 + offset] - src[6 + offset]); - t[l + 32] = src[1 + offset] + src[2 + offset] + 16 * 
(src[3 + offset] + src[4 + offset]) + - 81 * (src[5 + offset] + src[6 + offset]); - t[l + 40] = src[1 + offset] - src[2 + offset] + 32 * (src[3 + offset] - src[4 + offset]) + - 243 * (src[5 + offset] - src[6 + offset]); - t[l + 48] = src[1 + offset] + src[2 + offset] + 64 * (src[3 + offset] + src[4 + offset]) + - 729 * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; } for (int l = 0; l < 7; ++l) { int offset = l * 8; m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; - m[l + 7] = - t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + 3 * (t[5 + offset] - t[6 + offset]); - m[l + 14] = - t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + 9 * (t[5 + offset] + t[6 + offset]); - m[l + 21] = - t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + 27 * (t[5 + offset] - t[6 + offset]); - m[l + 28] = - t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + 81 * (t[5 + offset] + t[6 + offset]); - m[l + 35] = - t[1 + offset] - t[2 + offset] + 32 * (t[3 + offset] - t[4 + offset]) + 243 * (t[5 + offset] - t[6 + offset]); - m[l + 42] = t[1 + offset] + t[2 + offset] + 64 * (t[3 + offset] + t[4 + offset]) + - 729 * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; } // store output - for (int k = 0; k < 7; ++k) { - int dst_k_offset = k * dst_step * C4NUM; + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; int m_k_offset = k * 7; - for (int j = 0; j < 7; ++j) { - dst_data[i + dst_k_offset + j * C4NUM] = m[j + m_k_offset] + bias_data[i]; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; } } } @@ -1255,7 +3733,7 @@ int SelectOutputUnit(ConvParameter *conv_param) { for (int i = MIN_UNIT; i <= max_out_unit; ++i) { int input_unit = i + kernel_w - 1; - if (!GetOutputTransFunc(input_unit, i)) { + if (!GetOutputTransFunc(input_unit, i, ActType_No)) { continue; } float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; diff --git a/mindspore/lite/nnacl/winograd_utils.h b/mindspore/lite/nnacl/winograd_utils.h index 4fb06563f2..8e7e8745d7 100644 --- a/mindspore/lite/nnacl/winograd_utils.h +++ b/mindspore/lite/nnacl/winograd_utils.h @@ -31,7 +31,7 @@ extern "C" { typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step); typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step); + int dst_step, int out_c, int r_w, int r_h, int r_c); void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt, int src_step, int dst_step, int in_unit); @@ -169,84 +169,144 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step); -OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit); +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); #define Store4Data \ vst1q_f32(dst_data, m[0]); \ - vst1q_f32(dst_data + C4NUM, m[1]); \ - vst1q_f32(dst_data + dst_step * C4NUM, m[2]); \ - vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[3]); + vst1q_f32(dst_data + out_c, m[1]); \ + vst1q_f32(dst_data + dst_step * out_c, m[2]); \ + vst1q_f32(dst_data + dst_step * out_c + out_c, m[3]); #define Store9Data \ vst1q_f32(dst_data, m[0]); \ - vst1q_f32(dst_data + C4NUM, m[1]); \ - vst1q_f32(dst_data + 2 * C4NUM, m[2]); \ - vst1q_f32(dst_data + dst_step * C4NUM, m[3]); \ - vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[4]); \ - vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[5]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[6]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[7]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[8]); + vst1q_f32(dst_data + out_c, m[1]); \ + vst1q_f32(dst_data + 2 * out_c, m[2]); \ + vst1q_f32(dst_data + dst_step * out_c, m[3]); \ + vst1q_f32(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1q_f32(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); #define Store16Data \ vst1q_f32(dst_data, m[0]); \ - vst1q_f32(dst_data + C4NUM, m[1]); \ - vst1q_f32(dst_data + 2 * C4NUM, m[2]); \ - vst1q_f32(dst_data + 3 * C4NUM, m[3]); \ - vst1q_f32(dst_data + dst_step * C4NUM, m[4]); \ - vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[5]); \ - vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[6]); \ - vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[7]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[8]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[9]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[10]); \ - vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[11]); \ - vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[12]); \ - vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[13]); \ - vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, 
diff --git a/mindspore/lite/nnacl/winograd_utils.h b/mindspore/lite/nnacl/winograd_utils.h
index 4fb06563f2..8e7e8745d7 100644
--- a/mindspore/lite/nnacl/winograd_utils.h
+++ b/mindspore/lite/nnacl/winograd_utils.h
@@ -31,7 +31,7 @@ extern "C" {
 typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
 
 typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
-                                int dst_step);
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
 
 void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt, int src_step,
                                int dst_step, int in_unit);
@@ -169,84 +169,144 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step,
 
 void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
 
-OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit);
+OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
 
 #define Store4Data                 \
   vst1q_f32(dst_data, m[0]);       \
-  vst1q_f32(dst_data + C4NUM, m[1]);                    \
-  vst1q_f32(dst_data + dst_step * C4NUM, m[2]);         \
-  vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[3]);
+  vst1q_f32(dst_data + out_c, m[1]);                    \
+  vst1q_f32(dst_data + dst_step * out_c, m[2]);         \
+  vst1q_f32(dst_data + dst_step * out_c + out_c, m[3]);
 
 #define Store9Data                 \
   vst1q_f32(dst_data, m[0]);       \
-  vst1q_f32(dst_data + C4NUM, m[1]);                              \
-  vst1q_f32(dst_data + 2 * C4NUM, m[2]);                          \
-  vst1q_f32(dst_data + dst_step * C4NUM, m[3]);                   \
-  vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[4]);           \
-  vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[5]);       \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[6]);               \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[7]);       \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[8]);
+  vst1q_f32(dst_data + out_c, m[1]);                              \
+  vst1q_f32(dst_data + 2 * out_c, m[2]);                          \
+  vst1q_f32(dst_data + dst_step * out_c, m[3]);                   \
+  vst1q_f32(dst_data + dst_step * out_c + out_c, m[4]);           \
+  vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[5]);       \
+  vst1q_f32(dst_data + 2 * dst_step * out_c, m[6]);               \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[7]);       \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
 
 #define Store16Data                \
   vst1q_f32(dst_data, m[0]);       \
-  vst1q_f32(dst_data + C4NUM, m[1]);                              \
-  vst1q_f32(dst_data + 2 * C4NUM, m[2]);                          \
-  vst1q_f32(dst_data + 3 * C4NUM, m[3]);                          \
-  vst1q_f32(dst_data + dst_step * C4NUM, m[4]);                   \
-  vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[5]);           \
-  vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[6]);       \
-  vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[7]);       \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[8]);               \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[9]);       \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[10]);  \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[11]);  \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[12]);              \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[13]);      \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[14]);  \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[15]);
+  vst1q_f32(dst_data + out_c, m[1]);                              \
+  vst1q_f32(dst_data + 2 * out_c, m[2]);                          \
+  vst1q_f32(dst_data + 3 * out_c, m[3]);                          \
+  vst1q_f32(dst_data + dst_step * out_c, m[4]);                   \
+  vst1q_f32(dst_data + dst_step * out_c + out_c, m[5]);           \
+  vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[6]);       \
+  vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[7]);       \
+  vst1q_f32(dst_data + 2 * dst_step * out_c, m[8]);               \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[9]);       \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]);  \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]);  \
+  vst1q_f32(dst_data + 3 * dst_step * out_c, m[12]);              \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[13]);      \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]);  \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
 
 #define Store25Data                \
   vst1q_f32(dst_data, m[0]);       \
-  vst1q_f32(dst_data + C4NUM, m[1]);                              \
-  vst1q_f32(dst_data + 2 * C4NUM, m[2]);                          \
-  vst1q_f32(dst_data + 3 * C4NUM, m[3]);                          \
-  vst1q_f32(dst_data + 4 * C4NUM, m[4]);                          \
-  vst1q_f32(dst_data + dst_step * C4NUM, m[5]);                   \
-  vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[6]);           \
-  vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[7]);       \
-  vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[8]);       \
-  vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, m[9]);       \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[10]);              \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[11]);      \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[12]);  \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[13]);  \
-  vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, m[14]);  \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[15]);              \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[16]);      \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[17]);  \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[18]);  \
-  vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, m[19]);  \
-  vst1q_f32(dst_data + 4 * dst_step * C4NUM, m[20]);              \
-  vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, m[21]);      \
-  vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, m[22]);  \
-  vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, m[23]);  \
-  vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, m[24]);
+  vst1q_f32(dst_data + out_c, m[1]);                              \
+  vst1q_f32(dst_data + 2 * out_c, m[2]);                          \
+  vst1q_f32(dst_data + 3 * out_c, m[3]);                          \
+  vst1q_f32(dst_data + 4 * out_c, m[4]);                          \
+  vst1q_f32(dst_data + dst_step * out_c, m[5]);                   \
+  vst1q_f32(dst_data + dst_step * out_c + out_c, m[6]);           \
+  vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[7]);       \
+  vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[8]);       \
+  vst1q_f32(dst_data + dst_step * out_c + 4 * out_c, m[9]);       \
+  vst1q_f32(dst_data + 2 * dst_step * out_c, m[10]);              \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[11]);      \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]);  \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]);  \
+  vst1q_f32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]);  \
+  vst1q_f32(dst_data + 3 * dst_step * out_c, m[15]);              \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[16]);      \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]);  \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]);  \
+  vst1q_f32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]);  \
+  vst1q_f32(dst_data + 4 * dst_step * out_c, m[20]);              \
+  vst1q_f32(dst_data + 4 * dst_step * out_c + out_c, m[21]);      \
+  vst1q_f32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]);  \
+  vst1q_f32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]);  \
+  vst1q_f32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
 
-void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
+void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
 
-void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
+void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
 
-void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
-void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
+void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
+                            int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
 
 int SelectOutputUnit(ConvParameter *conv_param);
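The header now declares None/Relu/Relu6 variants for every tile size, and GetOutputTransFunc gains an ActType parameter to pick between them. The definition is not part of this diff; a plausible sketch of the lookup, shown for the 4x2 pair only (the real table in winograd_utils.c would cover every pair declared above), might look like:

#include <stddef.h>
#include "nnacl/winograd_utils.h"  // OutputTransFunc, ActType, unit declarations

// Illustrative sketch only; not the actual implementation from this patch.
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) {
  if (input_unit == 4 && output_unit == 2) {
    if (act_type == ActType_Relu) return OutputTransform4x2ReluUnit;
    if (act_type == ActType_Relu6) return OutputTransform4x2Relu6Unit;
    return OutputTransform4x2Unit;  // ActType_No: no fused activation
  }
  return NULL;  // unsupported pair; SelectOutputUnit skips such candidates
}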
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
index 46af13e272..af21fc75e6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
@@ -136,8 +136,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
   float matrix_at[64];
   float matrix_b[64];
   float matrix_bt[64];
+  float coef = 1.0f;
+  if (input_unit_ == 8) {
+    coef = 0.5f;
+  }
   auto ret =
-    CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
+    CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
     return ret;
@@ -243,17 +247,11 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
     MS_LOG(ERROR) << "in_func_ is null.";
     return RET_ERROR;
   }
-  out_func_ = GetOutputTransFunc(input_unit_, output_unit_);
+  out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
   if (out_func_ == nullptr) {
     MS_LOG(ERROR) << "out_func_ is null.";
     return RET_ERROR;
   }
-
-  // #ifdef ENABLE_ARM32
-  //   gemm_func_ = IndirectGemmFp32_8x4;
-  // #else
-  gemm_func_ = IndirectGemmFp32_8x8;
-  // #endif
   return RET_OK;
 }
@@ -300,12 +298,9 @@ int ConvolutionWinogradCPUKernel::ReSize() {
 }
 
 int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
-  if (gemm_func_ == nullptr) {
-    MS_LOG(ERROR) << "gemm_func is nullptr.";
-    return RET_ERROR;
-  }
+  auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
   ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), trans_weight_, reinterpret_cast<float *>(bias_data_),
-                   tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_, gemm_func_);
+                   output_data, tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
   return RET_OK;
 }
@@ -368,12 +363,6 @@ int ConvolutionWinogradCPUKernel::Run() {
     return RET_ERROR;
   }
 
-  ret = PostProcess();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Post process failed.";
-    FreeTmpBuffer();
-    return ret;
-  }
   FreeTmpBuffer();
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
index 28f11b57b3..f1e97291c3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
@@ -87,7 +87,6 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
   TmpBufferAddress tmp_buffer_address_list_[5];
   InputTransFunc in_func_;
   OutputTransFunc out_func_;
-  GEMM_FUNC_FP32 gemm_func_ = nullptr;
 };
 }  // namespace mindspore::kernel
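A closing note on the coef = 0.5f passed to CookToomFilter when input_unit_ == 8: the scale factor shrinks the Cook-Toom interpolation points, which appears to be where the new 8x8 transform constants come from. With points scaled from {±1, ±2, ±3} to {±0.5, ±1, ±1.5}, the largest multiplier in the transforms drops from 3^6 = 729 to 1.5^6 = 11.390625, which is far friendlier to fp32 accumulation. A quick standalone check of that arithmetic (illustration only, not part of the patch):

#include <math.h>
#include <stdio.h>

// Prints the per-power constants of the scaled points 0.5 and 1.5. They match the
// literals in the rewritten 8x8 transform above: 0.5/1.5, 0.25/2.25, 0.125/3.375,
// 0.0625/5.0625, 0.03125/7.59375, 0.015625/11.390625.
int main(void) {
  for (int p = 1; p <= 6; ++p) {
    printf("pow %d: %g %g\n", p, pow(0.5, p), pow(1.5, p));
  }
  return 0;
}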