/** * \file dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/arm_common/conv_bias/f16/helper.h" #include "src/arm_common/conv_bias/f16/strategy.h" #include "src/arm_common/elemwise_helper/op_unary.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include "midout.h" MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F45) using namespace megdnn; using namespace arm_common; namespace { struct FilterTransform4X5 { #define FILTER_TRANSFORM(d, wd) \ do { \ wd##0 = d##0; \ wd##r0 = d##r0; \ wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ wd##r1 = (d##r0 + d##r1 + d##r2 + d##r3 + d##r4) * -0.222168; \ wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ wd##r2 = (d##r0 - d##r1 + d##r2 - d##r3 + d##r4) * -0.222168; \ auto tmpd0 = d##0 * 0.710938; \ auto tmpd1 = d##1 * 0.355469; \ auto tmpd2 = d##2 * 0.177734; \ auto tmpd3 = d##3 * 0.088867; \ auto tmpd4 = d##4 * 0.044434; \ auto tmpdr0 = d##r0 * 0.710938; \ auto tmpdr1 = d##r1 * 0.355469; \ auto tmpdr2 = d##r2 * 0.177734; \ auto tmpdr3 = d##r3 * 0.088867; \ auto tmpdr4 = d##r4 * 0.044434; \ wd##3 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ wd##r3 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ wd##4 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ wd##r4 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ tmpd0 = d##0 * 0.011108; \ tmpd1 = d##1 * 0.022217; \ tmpd2 = d##2 * 0.044434; \ tmpd3 = d##3 * 0.088867; \ tmpd4 = d##4 * 0.177734; \ tmpdr0 = d##r0 * 0.011108; \ tmpdr1 = d##r1 * 0.022217; \ tmpdr2 = d##r2 * 0.044434; \ tmpdr3 = d##r3 * 0.088867; \ tmpdr4 = d##r4 * 0.177734; \ wd##5 = tmpd0 + tmpd1 + tmpd2 + tmpd3 + tmpd4; \ ; \ wd##r5 = tmpdr0 + tmpdr1 + tmpdr2 + tmpdr3 + tmpdr4; \ ; \ wd##6 = tmpd0 - tmpd1 + tmpd2 - tmpd3 + tmpd4; \ ; \ wd##r6 = tmpdr0 - tmpdr1 + tmpdr2 - tmpdr3 + tmpdr4; \ ; \ wd##7 = d##4; \ wd##r7 = d##r4; \ } while (0); #define FILTER_TRANSFORM_FINAL(d, wd) \ do { \ wd##0 = d##0; \ wd##1 = (d##0 + d##1 + d##2 + d##3 + d##4) * -0.222168; \ wd##2 = (d##0 - d##1 + d##2 - d##3 + d##4) * -0.222168; \ auto tmp0 = d##0 * 0.710938 + d##2 * 0.177734 + d##4 * 0.044434; \ auto tmp1 = d##1 * 0.355469 + d##3 * 0.088867; \ wd##3 = tmp0 + tmp1; \ wd##4 = tmp0 - tmp1; \ tmp0 = d##0 * 0.011108 + d##2 * 0.044434 + d##4 * 0.177734; \ tmp1 = d##1 * 0.022217 + d##3 * 0.088867; \ wd##5 = tmp0 + tmp1; \ wd##6 = tmp0 - tmp1; \ wd##7 = d##4; \ } while (0); static void transform(const __fp16* filter, __fp16* filter_transform_buf, __fp16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, size_t oc_end) { // Gg * GT // G //[[ 1. 0. 0. 0. 0. ] // [-0.2222222 -0.2222222 -0.2222222 -0.2222222 -0.2222222] // [-0.2222222 0.2222222 -0.2222222 0.2222222 -0.2222222] // [ 0.7111111 0.3555556 0.1777778 0.0888889 0.0444444] // [ 0.7111111 -0.3555556 0.1777778 -0.0888889 0.0444444] // [ 0.0111111 0.0222222 0.0444444 0.0888889 0.1777778] // [ 0.0111111 -0.0222222 0.0444444 -0.0888889 0.1777778] // [ 0. 0. 0. 0. 1. ]] constexpr size_t alpha = 4 + 5 - 1; for (size_t oc = oc_start; oc < oc_end; oc++) { rep(ic, IC) { const __fp16* fptr = filter + (oc * IC + ic) * 5 * 5; #define cb(i) Vector<__fp16, 4> g##i = Vector<__fp16, 4>::load(fptr + 5 * i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb #define cb(i) __fp16 gr##i = *(fptr + 5 * i + 4); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb #define cb(i) Vector<__fp16, 4> Gg##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) __fp16 Ggr##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) Vector<__fp16, 8> Ggt##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb #define cb(i) Vector<__fp16, 8> result##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb FILTER_TRANSFORM(g, Gg) #if MEGDNN_AARCH64 float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3, Ggr4, Ggr5, Ggr6, Ggr7}; Vector<__fp16, 8> Ggt4(vgr); TRANSPOSE_8x4(Gg, Ggt); FILTER_TRANSFORM_FINAL(Ggt, result); #define cb(i) result##i.save(transform_mid_buf + i * alpha); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = transform_mid_buf[j * alpha + i]; } #else #define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) #define cb(i) \ do { \ mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \ auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) + \ GET_VECTOR_FP16D_ELEM(Gg, i, 3); \ mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \ mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \ auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ auto tmp4 = Ggr##i * 0.0444444; \ tmp024 = tmp0 + tmp2 + tmp4; \ tmp13 = tmp1 + tmp3; \ mid_buf1[3] = tmp024 + tmp13; \ mid_buf1[4] = tmp024 - tmp13; \ tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ tmp4 = Ggr##i * 0.1777778; \ tmp024 = tmp0 + tmp2 + tmp4; \ tmp13 = tmp1 + tmp3; \ mid_buf1[5] = tmp024 + tmp13; \ mid_buf1[6] = tmp024 - tmp13; \ mid_buf1[7] = Ggr##i; \ mid_buf1 += 8; \ } while (0); __fp16* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); mid_buf1 = transform_mid_buf; #undef cb #undef GET_VECTOR_FP16D_ELEM rep(i, alpha) rep(j, alpha) { filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = transform_mid_buf[i * alpha + j]; } #endif } } } }; #undef FILTER_TRANSFORM #undef FILTER_TRANSFORM_FINAL struct InputTransform4X5 { #define INPUT_TRANSFORM(d, wd) \ do { \ wd##0 = (d##0 - d##6) + (d##4 - d##2) * 5.25f; \ auto tmp0 = d##2 - d##4 * 4.25f + d##6; \ auto tmp1 = d##1 - d##3 * 4.25f + d##5; \ wd##1 = tmp0 + tmp1; \ wd##2 = tmp0 - tmp1; \ tmp0 = d##2 * 4.0f - d##4 * 5.0f + d##6; \ tmp1 = d##1 * 2.0f - d##3 * 2.5f + d##5 * 0.5f; \ wd##3 = tmp0 + tmp1; \ wd##4 = tmp0 - tmp1; \ tmp0 = d##2 * 0.25f - d##4 * 1.25f + d##6; \ tmp1 = d##1 * 0.5f - d##3 * 2.5f + d##5 * 2.0f; \ wd##5 = tmp0 + tmp1; \ wd##6 = tmp0 - tmp1; \ wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0) #define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) template static void transform(const __fp16* input, __fp16* input_transform_buf, __fp16* transform_mid_buf, int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, size_t unit_idx, size_t nr_units_in_tile) { // BTd * B //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]])) constexpr size_t alpha = 4 + 5 - 1; if (!inner) { memset(transform_mid_buf, 0, sizeof(__fp16) * alpha * alpha); } #define cb(i) Vector<__fp16, 8> d##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb if (inner) { const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } else { int ih0_act = std::max(ih_start, 0), ih1_act = std::min(ih_start + alpha, IH), iw0_act = std::max(iw_start, 0), iw1_act = std::min(iw_start + alpha, IW); for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; transform_mid_buf[iho * alpha + iwo] = input[ic * IH * IW + ih * IW + iw]; } } #define cb(i) d##i = Vector<__fp16, 8>::load(transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb } #define cb(i) Vector<__fp16, 8> wd##i, ret##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb INPUT_TRANSFORM(d, wd); TRANSPOSE_8x8(wd, d); INPUT_TRANSFORM(d, ret); #define cb(i) ret##i.save(transform_mid_buf + i * alpha); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = transform_mid_buf[j * alpha + i]; } } }; #undef INPUT_TRANSFORM #define OUTPUT_TRANSFORM(m, s) \ do { \ s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \ m##6 * 2.0; \ s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \ m##6 * 4.0; \ s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \ m##6 * 8.0 + m##7; \ } while (0) template struct OutputTransform4X5 { static void transform(const dt_float16* output_transform_buf, const dt_float16* bias, dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! AT * m * A // AT f45 // 1.0 1.0 1.0 1.000 1.000 1.0 1.0 0.0 // 0.0 1.0 -1.0 0.500 -0.500 2.0 -2.0 0.0 // 0.0 1.0 1.0 0.250 0.250 4.0 4.0 0.0 // 0.0 1.0 -1.0 0.125 -0.125 8.0 -8.0 1.0 constexpr size_t alpha = 5 + 4 - 1; const __fp16* fp16_output_transform_buf = reinterpret_cast(output_transform_buf); const __fp16* fp16_bias = reinterpret_cast(bias); __fp16* fp16_output = reinterpret_cast<__fp16*>(output); __fp16* fp16_transform_mid_buf = reinterpret_cast<__fp16*>(transform_mid_buf); __fp16* mid_buf1 = fp16_transform_mid_buf; size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; #define cb(m, n) \ fp16_transform_mid_buf[m * alpha + n] = \ fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ OC + \ unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb #define cb(i) \ auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) Vector<__fp16, 8> s##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb #define cb(i) Vector<__fp16, 4> st##i; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) Vector<__fp16, 4> result##i; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb OUTPUT_TRANSFORM(m, s); TRANSPOSE_4x8(s, st); OUTPUT_TRANSFORM(st, result); TRANSPOSE_4x4(result, result); if (oh_start + 4 <= OH && ow_start + 4 <= OW) { int index = (oc * OH + oh_start) * OW + ow_start; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); result0.value = vadd_f16(result0.value, bias0); result1.value = vadd_f16(result1.value, bias0); result2.value = vadd_f16(result2.value, bias0); result3.value = vadd_f16(result3.value, bias0); } else if (bmode == BiasMode::BIAS) { float16x4_t bmbias0 = vld1_f16(fp16_bias + index); float16x4_t bmbias1 = vld1_f16(fp16_bias + index + OW); float16x4_t bmbias2 = vld1_f16(fp16_bias + index + OW * 2); float16x4_t bmbias3 = vld1_f16(fp16_bias + index + OW * 3); result0.value = vadd_f16(result0.value, bmbias0); result1.value = vadd_f16(result1.value, bmbias1); result2.value = vadd_f16(result2.value, bmbias2); result3.value = vadd_f16(result3.value, bmbias3); } float16x8_t item01 = op(vcombine_f16(result0.value, result1.value)); float16x8_t item23 = op(vcombine_f16(result2.value, result3.value)); vst1_f16(fp16_output + index, vget_low_f16(item01)); vst1_f16(fp16_output + index + OW, vget_high_f16(item01)); vst1_f16(fp16_output + index + OW * 2, vget_low_f16(item23)); vst1_f16(fp16_output + index + OW * 3, vget_high_f16(item23)); } else { #define cb(i) result##i.save(mid_buf1 + i * 4); mid_buf1 = fp16_transform_mid_buf; UNROLL_CALL_NOWRAPPER(4, cb); mid_buf1 = fp16_transform_mid_buf; #undef cb for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { size_t oh = oh_start + oho; size_t ow = ow_start + owo; __fp16 res = mid_buf1[oho * 4 + owo]; if (bmode == BiasMode::BIAS) { res += fp16_bias[oc * OH * OW + oh * OW + ow]; } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { res += fp16_bias[oc]; } res = op(res); fp16_output[oc * OH * OW + oh * OW + ow] = res; } } } } }; #undef OUTPUT_TRANSFORM #undef GET_VECTOR_FP16Q_ELEM } // namespace namespace megdnn { namespace arm_common { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16) void winograd_4x5_1x1_f16::filter(const dt_float16* filter, dt_float16* filter_transform_buf, dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, size_t oc_end) { FilterTransform4X5::transform( reinterpret_cast(filter), reinterpret_cast<__fp16*>(filter_transform_buf), reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, oc_end); } void winograd_4x5_1x1_f16::input(const dt_float16* input, dt_float16* input_transform_buf, dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { constexpr int alpha = 4 + 5 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); rep(ic, IC) { rep(unit_idx, nr_units_in_tile) { size_t index = unit_start_idx + unit_idx; size_t nh = index / units_w; size_t nw = index % units_w; int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { InputTransform4X5::transform( reinterpret_cast(input), reinterpret_cast<__fp16*>(input_transform_buf), reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); } else { InputTransform4X5::transform( reinterpret_cast(input), reinterpret_cast<__fp16*>(input_transform_buf), reinterpret_cast<__fp16*>(transform_mid_buf), ih_start, iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); } } } } void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf, const dt_float16* bias, dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); for (size_t oc = oc_start; oc < oc_end; oc++) { size_t oc_index = oc - oc_start; rep(unit_idx, nr_units_in_tile) { size_t index = unit_start_idx + unit_idx; auto nh = index / units_w; auto nw = index % units_w; size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; DISPATCH_CONV_WINOGRAD_BIAS( megdnn_arm_common_winograd_fp16_F45, cb, __fp16, __fp16, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); } } #undef cb } } // namespace winograd } // namespace arm_common } // namespace megdnn #endif // vim: syntax=cpp.doxygen