GitOrigin-RevId: 6d4b225ea5
tags/v0.3.2
| @@ -235,7 +235,7 @@ void StrategyHelper< | |||
| input_filter_compute_type* input_transform_buf, | |||
| input_filter_compute_type* transform_mid_buf, | |||
| int ih_start, int iw_start, size_t IH, size_t IW, | |||
| size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||
| size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||
| size_t m, size_t r, | |||
| const std::vector<float>& interp_points, DType dtype, | |||
| float rescale) { | |||
| @@ -284,7 +284,7 @@ void StrategyHelper< | |||
| const output_compute_type* bias, dst_type* output, | |||
| output_compute_type* transform_mid_buf, BiasMode bmode, | |||
| NonlineMode nonline_mode, size_t oh_start, | |||
| size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
| size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, | |||
| size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | |||
| size_t m, size_t r, | |||
| const std::vector<float>& interp_points, DType dtype, | |||
| @@ -296,7 +296,7 @@ void StrategyHelper< | |||
| output_compute_type* mid_buf1 = transform_mid_buf; | |||
| output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | |||
| OutputGetter<output_compute_type, dst_type> getter(dtype); | |||
| OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | |||
| OutputVisitor<layout, format> output_visitor(OC); | |||
| size_t oc = oc_start + oc_index; | |||
| @@ -6,8 +6,7 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| @@ -44,8 +43,8 @@ public: | |||
| input_filter_compute_type* input_transform_buf, | |||
| input_filter_compute_type* transform_mid_buf, | |||
| int ih_start, int iw_start, size_t IH, size_t IW, | |||
| size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||
| size_t m, size_t r, | |||
| size_t IC, size_t ic, size_t unit_idx, | |||
| size_t nr_units_in_tile, size_t m, size_t r, | |||
| const std::vector<float>& interp_points, DType dtype, | |||
| float rescale = 1.0f); | |||
| @@ -54,7 +53,7 @@ public: | |||
| const output_compute_type* bias, dst_type* output, | |||
| output_compute_type* transform_mid_buf, BiasMode bmode, | |||
| NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | |||
| size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||
| size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, | |||
| size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | |||
| const std::vector<float>& interp_points, DType dtype, | |||
| float input_filter_scale = 1.0f, // input_scale * filter_scale | |||
| @@ -45,7 +45,6 @@ public: | |||
| static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | |||
| for (auto&& algo : matmul_algos) { | |||
| if (algo->algoset() == | |||
| //! TODO: MK matmul should be filtered here | |||
| MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||
| continue; | |||
| } | |||
| @@ -536,7 +536,6 @@ public: | |||
| NonlineMode nonline_mode, size_t OH, size_t OW, \ | |||
| size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | |||
| size_t nr_tiles_in_unit); \ | |||
| }; | |||
| #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | |||
| @@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { | |||
| float* output, float* transform_mid_buf, | |||
| size_t oh_start, size_t ow_start, size_t OH, | |||
| size_t OW, size_t oc_start, size_t oc_end, | |||
| size_t unit_idx, size_t nr_units_in_tile, | |||
| const DType& src_dtype, const DType& dst_dtype) { | |||
| size_t oc_index, size_t unit_idx, | |||
| size_t nr_units_in_tile, const DType& src_dtype, | |||
| const DType& dst_dtype) { | |||
| MEGDNN_MARK_USED_VAR(transform_mid_buf); | |||
| megdnn_assert( | |||
| (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||
| oc_end % 8 == 0, | |||
| "Winograd output transform input param is not times of 8!"); | |||
| Op op(src_dtype, dst_dtype); | |||
| //! AT * m * A | |||
| size_t OCB = (oc_end - oc_start) / 8; | |||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
| size_t ocb = (oc - oc_start) / 8; | |||
| size_t oc = oc_start + oc_index; | |||
| size_t ocb = oc_index / 8; | |||
| #define cb(m, n) \ | |||
| auto v##m##n = Vector<float, 8>::load( \ | |||
| output_transform_buf + \ | |||
| (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | |||
| ocb * nr_units_in_tile * 8 + unit_idx * 8); | |||
| UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
| UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
| #undef cb | |||
| //! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
| //! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
| //! v20 v21 v22 v23 1 -1 | |||
| //! v30 v31 v32 v33 0 1 | |||
| //! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
| //! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
| //! v20 v21 v22 v23 1 -1 | |||
| //! v30 v31 v32 v33 0 1 | |||
| #define cb(m) \ | |||
| auto t0##m = v0##m + v1##m + v2##m; \ | |||
| auto t1##m = v1##m - v2##m + v3##m; | |||
| UNROLL_CALL_NOWRAPPER(4, cb); | |||
| UNROLL_CALL_NOWRAPPER(4, cb); | |||
| #undef cb | |||
| #define cb(m) \ | |||
| v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | |||
| v##m##1 = t##m##1 - t##m##2 + t##m##3; | |||
| UNROLL_CALL_NOWRAPPER(2, cb); | |||
| UNROLL_CALL_NOWRAPPER(2, cb); | |||
| #undef cb | |||
| Vector<float, 8> vbias; | |||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| vbias = Vector<float, 8>::load(bias + oc); | |||
| Vector<float, 8> vbias; | |||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| vbias = Vector<float, 8>::load(bias + oc); | |||
| #define cb(m, n) v##m##n += vbias; | |||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||
| #undef cb | |||
| } | |||
| if (bmode != BiasMode::BIAS) { | |||
| } | |||
| if (bmode != BiasMode::BIAS) { | |||
| #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | |||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||
| #undef cb | |||
| } | |||
| } | |||
| #define out_save(oho, owo) \ | |||
| do { \ | |||
| size_t oh = oh_start + oho; \ | |||
| @@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { | |||
| ow * 8); \ | |||
| } \ | |||
| } while (0); | |||
| UNROLL_CALL_RAW_D2(2, 2, out_save); | |||
| } | |||
| UNROLL_CALL_RAW_D2(2, 2, out_save); | |||
| } | |||
| }; | |||
| #undef CONCAT | |||
| @@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, | |||
| } | |||
| } | |||
| void winograd_nchw88_2x3_8x8_f::output( | |||
| const float* output_transform_buf, const float* bias, float* output, | |||
| float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||
| size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
| size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||
| void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf, | |||
| const float* bias, float* output, | |||
| float* transform_mid_buf, BiasMode bmode, | |||
| NonlineMode nonline_mode, size_t OH, | |||
| size_t OW, size_t oc_start, | |||
| size_t oc_end, size_t unit_start_idx, | |||
| size_t nr_units_in_tile) { | |||
| #define cb(_bmode, _nonline_op, ...) \ | |||
| OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||
| __VA_ARGS__); | |||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||
| megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, | |||
| float, bmode, nonline_mode, output_transform_buf, bias, output, | |||
| transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
| unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
| size_t OC = oc_end - oc_start; | |||
| megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||
| "Winograd output transform input param is not times of 8!"); | |||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
| size_t oc_index = oc - oc_start; | |||
| rep(unit_idx, nr_units_in_tile) { | |||
| size_t index = unit_start_idx + unit_idx; | |||
| auto nh = index / units_w; | |||
| auto nw = index % units_w; | |||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||
| megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, | |||
| float, float, bmode, nonline_mode, output_transform_buf, | |||
| bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||
| oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, | |||
| dst_dtype); | |||
| } | |||
| } | |||
| #undef cb | |||
| } | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/common/unroll_macro.h" | |||
| @@ -19,10 +20,10 @@ | |||
| #include <x86intrin.h> | |||
| #ifdef WIN32CMAKE | |||
| #include <avxintrin.h> | |||
| #include <smmintrin.h> | |||
| #include <avx2intrin.h> | |||
| #include <avxintrin.h> | |||
| #include <fmaintrin.h> | |||
| #include <smmintrin.h> | |||
| #endif | |||
| #include "midout.h" | |||
| @@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { | |||
| int ih_start, int iw_start, size_t IH, size_t IW, | |||
| size_t ic, size_t IC) { | |||
| MEGDNN_MARK_USED_VAR(patch); | |||
| size_t IW8 = IW * 8; //! For nchw88 mode | |||
| size_t IW8 = IW * 8; //! For nchw88 mode | |||
| size_t iw8_start = iw_start * 8; //! For nchw88 mode | |||
| size_t icb = ic / 8; | |||
| if (!(inner && ic + 8 < IC)) { | |||
| @@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { | |||
| for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | |||
| for (size_t icb = 0; icb < ICB; icb++) { | |||
| for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ | |||
| for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) { | |||
| const float* fptr = filter + | |||
| (ocb * ICB + icb) * 3 * 3 * 8 * 8 + | |||
| ic_inner * 8; | |||
| @@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { | |||
| float* output, float* transform_mid_buf, | |||
| size_t oh_start, size_t ow_start, size_t OH, | |||
| size_t OW, size_t oc_start, size_t oc_end, | |||
| size_t unit_idx, size_t nr_units_in_tile, | |||
| const DType& src_dtype, const DType& dst_dtype) { | |||
| size_t oc_index, size_t unit_idx, | |||
| size_t nr_units_in_tile, const DType& src_dtype, | |||
| const DType& dst_dtype) { | |||
| MEGDNN_MARK_USED_VAR(transform_mid_buf); | |||
| megdnn_assert( | |||
| (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||
| oc_end % 8 == 0, | |||
| "Winograd output transform input param is not times of 8!"); | |||
| Op op(src_dtype, dst_dtype); | |||
| //! AT * m * A | |||
| size_t OCB = (oc_end - oc_start) / 8; | |||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
| size_t ocb = (oc - oc_start) / 8; | |||
| size_t oc = oc_start + oc_index; | |||
| size_t ocb = oc_index / 8; | |||
| #define cb(m, n) \ | |||
| auto v##m##n = Vector<float, 8>::load( \ | |||
| output_transform_buf + \ | |||
| (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | |||
| ocb * nr_units_in_tile * 8 + unit_idx * 8); | |||
| UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
| UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
| #undef cb | |||
| /** | |||
| * A | |||
| * | |||
| * 1 0 0 0 0 0 | |||
| * 1 1 1 1 1 1 | |||
| * 1 -1 1 -1 1 -1 | |||
| * 1 2 4 8 16 32 | |||
| * 1 -2 4 -8 16 -32 | |||
| * 1 0.5 0.25 0.125 0.0625 0.03125 | |||
| * 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
| * 0 0.0 0 0 0 1 | |||
| */ | |||
| Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, | |||
| v5subv6; | |||
| /** | |||
| * A | |||
| * | |||
| * 1 0 0 0 0 0 | |||
| * 1 1 1 1 1 1 | |||
| * 1 -1 1 -1 1 -1 | |||
| * 1 2 4 8 16 32 | |||
| * 1 -2 4 -8 16 -32 | |||
| * 1 0.5 0.25 0.125 0.0625 0.03125 | |||
| * 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
| * 0 0.0 0 0 0 1 | |||
| */ | |||
| Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
| #define cb(m) \ | |||
| v1addv2 = v1##m + v2##m; \ | |||
| v1subv2 = v1##m - v2##m; \ | |||
| @@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { | |||
| auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
| auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
| UNROLL_CALL_NOWRAPPER(8, cb); | |||
| UNROLL_CALL_NOWRAPPER(8, cb); | |||
| #undef cb | |||
| #define cb(m) \ | |||
| @@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { | |||
| v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
| v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
| UNROLL_CALL_NOWRAPPER(6, cb); | |||
| UNROLL_CALL_NOWRAPPER(6, cb); | |||
| #undef cb | |||
| Vector<float, 8> vbias; | |||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| vbias = Vector<float, 8>::load(bias + oc); | |||
| Vector<float, 8> vbias; | |||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| vbias = Vector<float, 8>::load(bias + oc); | |||
| #define cb(m, n) v##m##n += vbias; | |||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||
| #undef cb | |||
| } | |||
| if (bmode != BiasMode::BIAS) { | |||
| } | |||
| if (bmode != BiasMode::BIAS) { | |||
| #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | |||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||
| #undef cb | |||
| } | |||
| } | |||
| #define out_save(oho, owo) \ | |||
| do { \ | |||
| size_t oh = oh_start + oho; \ | |||
| @@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { | |||
| ow * 8); \ | |||
| } \ | |||
| } while (0); | |||
| UNROLL_CALL_RAW_D2(6, 6, out_save); | |||
| } | |||
| UNROLL_CALL_RAW_D2(6, 6, out_save); | |||
| } | |||
| }; | |||
| #undef CONCAT | |||
| @@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||
| megdnn_assert(IC % 8 == 0); | |||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
| auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
| auto units_w = | |||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
| float* patch = transform_mid_buf; | |||
| float* patchT = transform_mid_buf + 8 * alpha * alpha; | |||
| @@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||
| } | |||
| } | |||
| void winograd_nchw88_6x3_8x8_f::output( | |||
| const float* output_transform_buf, const float* bias, float* output, | |||
| float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||
| size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
| size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||
| void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf, | |||
| const float* bias, float* output, | |||
| float* transform_mid_buf, BiasMode bmode, | |||
| NonlineMode nonline_mode, size_t OH, | |||
| size_t OW, size_t oc_start, | |||
| size_t oc_end, size_t unit_start_idx, | |||
| size_t nr_units_in_tile) { | |||
| #define cb(_bmode, _nonline_op, ...) \ | |||
| OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||
| __VA_ARGS__); | |||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||
| megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, | |||
| float, bmode, nonline_mode, output_transform_buf, bias, output, | |||
| transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
| unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
| size_t OC = oc_end - oc_start; | |||
| megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||
| "Winograd output transform input param is not times of 8!"); | |||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
| size_t oc_index = oc - oc_start; | |||
| rep(unit_idx, nr_units_in_tile) { | |||
| size_t index = unit_start_idx + unit_idx; | |||
| auto nh = index / units_w; | |||
| auto nw = index % units_w; | |||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||
| megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, | |||
| float, float, bmode, nonline_mode, output_transform_buf, | |||
| bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||
| oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, | |||
| src_dtype, dst_dtype); | |||
| } | |||
| } | |||
| #undef cb | |||
| } | |||
| } // namespace winograd | |||
| } // namespace arm_common | |||
| } // namespace x86 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||