GitOrigin-RevId: 6d4b225ea5
tags/v0.3.2
| @@ -235,7 +235,7 @@ void StrategyHelper< | |||||
| input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
| input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
| int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
| size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t m, size_t r, | size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float rescale) { | float rescale) { | ||||
| @@ -284,7 +284,7 @@ void StrategyHelper< | |||||
| const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
| output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | NonlineMode nonline_mode, size_t oh_start, | ||||
| size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
| size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, | |||||
| size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | ||||
| size_t m, size_t r, | size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| @@ -296,7 +296,7 @@ void StrategyHelper< | |||||
| output_compute_type* mid_buf1 = transform_mid_buf; | output_compute_type* mid_buf1 = transform_mid_buf; | ||||
| output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | ||||
| OutputGetter<output_compute_type, dst_type> getter(dtype); | OutputGetter<output_compute_type, dst_type> getter(dtype); | ||||
| OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | |||||
| OutputVisitor<layout, format> output_visitor(OC); | |||||
| size_t oc = oc_start + oc_index; | size_t oc = oc_start + oc_index; | ||||
| @@ -6,8 +6,7 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -44,8 +43,8 @@ public: | |||||
| input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
| input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
| int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
| size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t m, size_t r, | |||||
| size_t IC, size_t ic, size_t unit_idx, | |||||
| size_t nr_units_in_tile, size_t m, size_t r, | |||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float rescale = 1.0f); | float rescale = 1.0f); | ||||
| @@ -54,7 +53,7 @@ public: | |||||
| const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
| output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | ||||
| size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||||
| size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, | |||||
| size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float input_filter_scale = 1.0f, // input_scale * filter_scale | float input_filter_scale = 1.0f, // input_scale * filter_scale | ||||
| @@ -45,7 +45,6 @@ public: | |||||
| static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | ||||
| for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
| if (algo->algoset() == | if (algo->algoset() == | ||||
| //! TODO: threre should filter MK matmul | |||||
| MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -536,7 +536,6 @@ public: | |||||
| NonlineMode nonline_mode, size_t OH, size_t OW, \ | NonlineMode nonline_mode, size_t OH, size_t OW, \ | ||||
| size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | ||||
| size_t nr_tiles_in_unit); \ | size_t nr_tiles_in_unit); \ | ||||
| }; | }; | ||||
| #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | ||||
| @@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { | |||||
| float* output, float* transform_mid_buf, | float* output, float* transform_mid_buf, | ||||
| size_t oh_start, size_t ow_start, size_t OH, | size_t oh_start, size_t ow_start, size_t OH, | ||||
| size_t OW, size_t oc_start, size_t oc_end, | size_t OW, size_t oc_start, size_t oc_end, | ||||
| size_t unit_idx, size_t nr_units_in_tile, | |||||
| const DType& src_dtype, const DType& dst_dtype) { | |||||
| size_t oc_index, size_t unit_idx, | |||||
| size_t nr_units_in_tile, const DType& src_dtype, | |||||
| const DType& dst_dtype) { | |||||
| MEGDNN_MARK_USED_VAR(transform_mid_buf); | MEGDNN_MARK_USED_VAR(transform_mid_buf); | ||||
| megdnn_assert( | |||||
| (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||||
| oc_end % 8 == 0, | |||||
| "Winograd output transform input param is not times of 8!"); | |||||
| Op op(src_dtype, dst_dtype); | Op op(src_dtype, dst_dtype); | ||||
| //! AT * m * A | //! AT * m * A | ||||
| size_t OCB = (oc_end - oc_start) / 8; | size_t OCB = (oc_end - oc_start) / 8; | ||||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
| size_t ocb = (oc - oc_start) / 8; | |||||
| size_t oc = oc_start + oc_index; | |||||
| size_t ocb = oc_index / 8; | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| auto v##m##n = Vector<float, 8>::load( \ | auto v##m##n = Vector<float, 8>::load( \ | ||||
| output_transform_buf + \ | output_transform_buf + \ | ||||
| (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | ||||
| ocb * nr_units_in_tile * 8 + unit_idx * 8); | ocb * nr_units_in_tile * 8 + unit_idx * 8); | ||||
| UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
| UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
| #undef cb | #undef cb | ||||
| //! 1 1 1 0 v00 v01 v02 v03 1 0 | |||||
| //! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||||
| //! v20 v21 v22 v23 1 -1 | |||||
| //! v30 v31 v32 v33 0 1 | |||||
| //! 1 1 1 0 v00 v01 v02 v03 1 0 | |||||
| //! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||||
| //! v20 v21 v22 v23 1 -1 | |||||
| //! v30 v31 v32 v33 0 1 | |||||
| #define cb(m) \ | #define cb(m) \ | ||||
| auto t0##m = v0##m + v1##m + v2##m; \ | auto t0##m = v0##m + v1##m + v2##m; \ | ||||
| auto t1##m = v1##m - v2##m + v3##m; | auto t1##m = v1##m - v2##m + v3##m; | ||||
| UNROLL_CALL_NOWRAPPER(4, cb); | |||||
| UNROLL_CALL_NOWRAPPER(4, cb); | |||||
| #undef cb | #undef cb | ||||
| #define cb(m) \ | #define cb(m) \ | ||||
| v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | ||||
| v##m##1 = t##m##1 - t##m##2 + t##m##3; | v##m##1 = t##m##1 - t##m##2 + t##m##3; | ||||
| UNROLL_CALL_NOWRAPPER(2, cb); | |||||
| UNROLL_CALL_NOWRAPPER(2, cb); | |||||
| #undef cb | #undef cb | ||||
| Vector<float, 8> vbias; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| vbias = Vector<float, 8>::load(bias + oc); | |||||
| Vector<float, 8> vbias; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| vbias = Vector<float, 8>::load(bias + oc); | |||||
| #define cb(m, n) v##m##n += vbias; | #define cb(m, n) v##m##n += vbias; | ||||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
| #undef cb | #undef cb | ||||
| } | |||||
| if (bmode != BiasMode::BIAS) { | |||||
| } | |||||
| if (bmode != BiasMode::BIAS) { | |||||
| #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | ||||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
| UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
| #undef cb | #undef cb | ||||
| } | |||||
| } | |||||
| #define out_save(oho, owo) \ | #define out_save(oho, owo) \ | ||||
| do { \ | do { \ | ||||
| size_t oh = oh_start + oho; \ | size_t oh = oh_start + oho; \ | ||||
| @@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { | |||||
| ow * 8); \ | ow * 8); \ | ||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| UNROLL_CALL_RAW_D2(2, 2, out_save); | |||||
| } | |||||
| UNROLL_CALL_RAW_D2(2, 2, out_save); | |||||
| } | } | ||||
| }; | }; | ||||
| #undef CONCAT | #undef CONCAT | ||||
| @@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, | |||||
| } | } | ||||
| } | } | ||||
| void winograd_nchw88_2x3_8x8_f::output( | |||||
| const float* output_transform_buf, const float* bias, float* output, | |||||
| float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||||
| size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
| size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||||
| void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf, | |||||
| const float* bias, float* output, | |||||
| float* transform_mid_buf, BiasMode bmode, | |||||
| NonlineMode nonline_mode, size_t OH, | |||||
| size_t OW, size_t oc_start, | |||||
| size_t oc_end, size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| #define cb(_bmode, _nonline_op, ...) \ | #define cb(_bmode, _nonline_op, ...) \ | ||||
| OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | ||||
| __VA_ARGS__); | __VA_ARGS__); | ||||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||||
| megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, | |||||
| float, bmode, nonline_mode, output_transform_buf, bias, output, | |||||
| transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
| unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||||
| "Winograd output transform input param is not times of 8!"); | |||||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||||
| megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, | |||||
| float, float, bmode, nonline_mode, output_transform_buf, | |||||
| bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||||
| oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, | |||||
| dst_dtype); | |||||
| } | |||||
| } | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -19,10 +20,10 @@ | |||||
| #include <x86intrin.h> | #include <x86intrin.h> | ||||
| #ifdef WIN32CMAKE | #ifdef WIN32CMAKE | ||||
| #include <avxintrin.h> | |||||
| #include <smmintrin.h> | |||||
| #include <avx2intrin.h> | #include <avx2intrin.h> | ||||
| #include <avxintrin.h> | |||||
| #include <fmaintrin.h> | #include <fmaintrin.h> | ||||
| #include <smmintrin.h> | |||||
| #endif | #endif | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { | |||||
| int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
| size_t ic, size_t IC) { | size_t ic, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(patch); | MEGDNN_MARK_USED_VAR(patch); | ||||
| size_t IW8 = IW * 8; //! For nchw88 mode | |||||
| size_t IW8 = IW * 8; //! For nchw88 mode | |||||
| size_t iw8_start = iw_start * 8; //! For nchw88 mode | size_t iw8_start = iw_start * 8; //! For nchw88 mode | ||||
| size_t icb = ic / 8; | size_t icb = ic / 8; | ||||
| if (!(inner && ic + 8 < IC)) { | if (!(inner && ic + 8 < IC)) { | ||||
| @@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { | |||||
| for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | ||||
| for (size_t icb = 0; icb < ICB; icb++) { | for (size_t icb = 0; icb < ICB; icb++) { | ||||
| for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ | |||||
| for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) { | |||||
| const float* fptr = filter + | const float* fptr = filter + | ||||
| (ocb * ICB + icb) * 3 * 3 * 8 * 8 + | (ocb * ICB + icb) * 3 * 3 * 8 * 8 + | ||||
| ic_inner * 8; | ic_inner * 8; | ||||
| @@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { | |||||
| float* output, float* transform_mid_buf, | float* output, float* transform_mid_buf, | ||||
| size_t oh_start, size_t ow_start, size_t OH, | size_t oh_start, size_t ow_start, size_t OH, | ||||
| size_t OW, size_t oc_start, size_t oc_end, | size_t OW, size_t oc_start, size_t oc_end, | ||||
| size_t unit_idx, size_t nr_units_in_tile, | |||||
| const DType& src_dtype, const DType& dst_dtype) { | |||||
| size_t oc_index, size_t unit_idx, | |||||
| size_t nr_units_in_tile, const DType& src_dtype, | |||||
| const DType& dst_dtype) { | |||||
| MEGDNN_MARK_USED_VAR(transform_mid_buf); | MEGDNN_MARK_USED_VAR(transform_mid_buf); | ||||
| megdnn_assert( | |||||
| (oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||||
| oc_end % 8 == 0, | |||||
| "Winograd output transform input param is not times of 8!"); | |||||
| Op op(src_dtype, dst_dtype); | Op op(src_dtype, dst_dtype); | ||||
| //! AT * m * A | //! AT * m * A | ||||
| size_t OCB = (oc_end - oc_start) / 8; | size_t OCB = (oc_end - oc_start) / 8; | ||||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
| size_t ocb = (oc - oc_start) / 8; | |||||
| size_t oc = oc_start + oc_index; | |||||
| size_t ocb = oc_index / 8; | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| auto v##m##n = Vector<float, 8>::load( \ | auto v##m##n = Vector<float, 8>::load( \ | ||||
| output_transform_buf + \ | output_transform_buf + \ | ||||
| (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | ||||
| ocb * nr_units_in_tile * 8 + unit_idx * 8); | ocb * nr_units_in_tile * 8 + unit_idx * 8); | ||||
| UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||||
| UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||||
| #undef cb | #undef cb | ||||
| /** | |||||
| * A | |||||
| * | |||||
| * 1 0 0 0 0 0 | |||||
| * 1 1 1 1 1 1 | |||||
| * 1 -1 1 -1 1 -1 | |||||
| * 1 2 4 8 16 32 | |||||
| * 1 -2 4 -8 16 -32 | |||||
| * 1 0.5 0.25 0.125 0.0625 0.03125 | |||||
| * 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||||
| * 0 0.0 0 0 0 1 | |||||
| */ | |||||
| Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, | |||||
| v5subv6; | |||||
| /** | |||||
| * A | |||||
| * | |||||
| * 1 0 0 0 0 0 | |||||
| * 1 1 1 1 1 1 | |||||
| * 1 -1 1 -1 1 -1 | |||||
| * 1 2 4 8 16 32 | |||||
| * 1 -2 4 -8 16 -32 | |||||
| * 1 0.5 0.25 0.125 0.0625 0.03125 | |||||
| * 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||||
| * 0 0.0 0 0 0 1 | |||||
| */ | |||||
| Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
| #define cb(m) \ | #define cb(m) \ | ||||
| v1addv2 = v1##m + v2##m; \ | v1addv2 = v1##m + v2##m; \ | ||||
| v1subv2 = v1##m - v2##m; \ | v1subv2 = v1##m - v2##m; \ | ||||
| @@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { | |||||
| auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | ||||
| auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | ||||
| UNROLL_CALL_NOWRAPPER(8, cb); | |||||
| UNROLL_CALL_NOWRAPPER(8, cb); | |||||
| #undef cb | #undef cb | ||||
| #define cb(m) \ | #define cb(m) \ | ||||
| @@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { | |||||
| v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | ||||
| v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | ||||
| UNROLL_CALL_NOWRAPPER(6, cb); | |||||
| UNROLL_CALL_NOWRAPPER(6, cb); | |||||
| #undef cb | #undef cb | ||||
| Vector<float, 8> vbias; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| vbias = Vector<float, 8>::load(bias + oc); | |||||
| Vector<float, 8> vbias; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| vbias = Vector<float, 8>::load(bias + oc); | |||||
| #define cb(m, n) v##m##n += vbias; | #define cb(m, n) v##m##n += vbias; | ||||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
| #undef cb | #undef cb | ||||
| } | |||||
| if (bmode != BiasMode::BIAS) { | |||||
| } | |||||
| if (bmode != BiasMode::BIAS) { | |||||
| #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | ||||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
| UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
| #undef cb | #undef cb | ||||
| } | |||||
| } | |||||
| #define out_save(oho, owo) \ | #define out_save(oho, owo) \ | ||||
| do { \ | do { \ | ||||
| size_t oh = oh_start + oho; \ | size_t oh = oh_start + oho; \ | ||||
| @@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { | |||||
| ow * 8); \ | ow * 8); \ | ||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| UNROLL_CALL_RAW_D2(6, 6, out_save); | |||||
| } | |||||
| UNROLL_CALL_RAW_D2(6, 6, out_save); | |||||
| } | } | ||||
| }; | }; | ||||
| #undef CONCAT | #undef CONCAT | ||||
| @@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||||
| megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | // OW = IW + 2 * PW - KERNEL_SIZE + 1 | ||||
| auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| auto units_w = | |||||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
| float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
| @@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||||
| } | } | ||||
| } | } | ||||
| void winograd_nchw88_6x3_8x8_f::output( | |||||
| const float* output_transform_buf, const float* bias, float* output, | |||||
| float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||||
| size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
| size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||||
| void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf, | |||||
| const float* bias, float* output, | |||||
| float* transform_mid_buf, BiasMode bmode, | |||||
| NonlineMode nonline_mode, size_t OH, | |||||
| size_t OW, size_t oc_start, | |||||
| size_t oc_end, size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| #define cb(_bmode, _nonline_op, ...) \ | #define cb(_bmode, _nonline_op, ...) \ | ||||
| OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | ||||
| __VA_ARGS__); | __VA_ARGS__); | ||||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||||
| megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, | |||||
| float, bmode, nonline_mode, output_transform_buf, bias, output, | |||||
| transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
| unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||||
| "Winograd output transform input param is not times of 8!"); | |||||
| for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| DISPATCH_CONV_WINOGRAD_BIAS( | |||||
| megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, | |||||
| float, float, bmode, nonline_mode, output_transform_buf, | |||||
| bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||||
| oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, | |||||
| src_dtype, dst_dtype); | |||||
| } | |||||
| } | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| } // namespace winograd | } // namespace winograd | ||||
| } // namespace arm_common | |||||
| } // namespace x86 | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||