GitOrigin-RevId: a43077550c
tags/v0.3.2
| @@ -247,33 +247,31 @@ void StrategyHelper< | |||||
| Getter<ctype, input_filter_compute_type> getter(dtype); | Getter<ctype, input_filter_compute_type> getter(dtype); | ||||
| InputVisitor<layout, format> intput_visitor(IC); | InputVisitor<layout, format> intput_visitor(IC); | ||||
| rep(ic, IC) { | |||||
| memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| int ih = ih_start + i; | |||||
| int iw = iw_start + j; | |||||
| if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||||
| mid_buf1[i * alpha + j] = getter( | |||||
| input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||||
| } | |||||
| memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| int ih = ih_start + i; | |||||
| int iw = iw_start + j; | |||||
| if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||||
| mid_buf1[i * alpha + j] = getter( | |||||
| input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||||
| } | } | ||||
| } | |||||
| megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
| input_filter_compute_type, true, | |||||
| false>( | |||||
| winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||||
| alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
| megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
| input_filter_compute_type, false, | |||||
| false>( | |||||
| mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||||
| alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||||
| unit_idx, i, j)] = | |||||
| mid_buf1[i * alpha + j]; | |||||
| } | |||||
| megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
| input_filter_compute_type, true, | |||||
| false>( | |||||
| winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||||
| alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
| megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
| input_filter_compute_type, false, | |||||
| false>( | |||||
| mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||||
| alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||||
| unit_idx, i, j)] = | |||||
| mid_buf1[i * alpha + j]; | |||||
| } | } | ||||
| } | } | ||||
| @@ -287,7 +285,7 @@ void StrategyHelper< | |||||
| output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | NonlineMode nonline_mode, size_t oh_start, | ||||
| size_t ow_start, size_t OH, size_t OW, size_t oc_start, | size_t ow_start, size_t OH, size_t OW, size_t oc_start, | ||||
| size_t oc_end, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t m, size_t r, | size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float input_filter_scale, float input_filter_rescale, | float input_filter_scale, float input_filter_rescale, | ||||
| @@ -300,49 +298,49 @@ void StrategyHelper< | |||||
| OutputGetter<output_compute_type, dst_type> getter(dtype); | OutputGetter<output_compute_type, dst_type> getter(dtype); | ||||
| OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | ||||
| for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
| /* gather */ | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||||
| alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i, | |||||
| j)]; | |||||
| } | |||||
| /* A[alpha*m] M[alpha*alpha] */ | |||||
| megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
| output_compute_type, true, false>( | |||||
| winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||||
| alpha, m, alpha, alpha, dtype, dtype); | |||||
| megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
| output_compute_type, false, false>( | |||||
| mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||||
| alpha, alpha, m, m, dtype, dtype); | |||||
| rep(i, m) rep(j, m) { | |||||
| auto oh = oh_start + i; | |||||
| auto ow = ow_start + j; | |||||
| if (oh < OH && ow < OW) { | |||||
| float val = mid_buf1[i * m + j]; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| val += bias[oc] * input_filter_rescale * | |||||
| input_filter_rescale; | |||||
| } else if (bmode == BiasMode::BIAS) { | |||||
| val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||||
| input_filter_rescale * input_filter_rescale; | |||||
| } | |||||
| val = val * input_filter_scale / | |||||
| (input_filter_rescale * input_filter_rescale * rescale * | |||||
| rescale); | |||||
| if (nonline_mode == NonlineMode::RELU) { | |||||
| val = val > 0 ? val : 0; | |||||
| } else if (nonline_mode == NonlineMode::SIGMOID) { | |||||
| val = 1.f / (expf(-val) + 1.f); | |||||
| } else if (nonline_mode == NonlineMode::H_SWISH) { | |||||
| val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||||
| } else { | |||||
| megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||||
| } | |||||
| output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||||
| size_t oc = oc_start + oc_index; | |||||
| /* gather */ | |||||
| rep(i, alpha) rep(j, alpha) { | |||||
| mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||||
| alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, | |||||
| j)]; | |||||
| } | |||||
| /* A[alpha*m] M[alpha*alpha] */ | |||||
| megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
| output_compute_type, true, false>( | |||||
| winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||||
| alpha, m, alpha, alpha, dtype, dtype); | |||||
| megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
| output_compute_type, false, false>( | |||||
| mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||||
| alpha, alpha, m, m, dtype, dtype); | |||||
| rep(i, m) rep(j, m) { | |||||
| auto oh = oh_start + i; | |||||
| auto ow = ow_start + j; | |||||
| if (oh < OH && ow < OW) { | |||||
| float val = mid_buf1[i * m + j]; | |||||
| if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| val += bias[oc] * input_filter_rescale * | |||||
| input_filter_rescale; | |||||
| } else if (bmode == BiasMode::BIAS) { | |||||
| val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||||
| input_filter_rescale * input_filter_rescale; | |||||
| } | |||||
| val = val * input_filter_scale / | |||||
| (input_filter_rescale * input_filter_rescale * rescale * | |||||
| rescale); | |||||
| if (nonline_mode == NonlineMode::RELU) { | |||||
| val = val > 0 ? val : 0; | |||||
| } else if (nonline_mode == NonlineMode::SIGMOID) { | |||||
| val = 1.f / (expf(-val) + 1.f); | |||||
| } else if (nonline_mode == NonlineMode::H_SWISH) { | |||||
| val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||||
| } else { | |||||
| megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||||
| } | } | ||||
| output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -44,7 +44,7 @@ public: | |||||
| input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
| input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
| int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
| size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
| size_t m, size_t r, | size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float rescale = 1.0f); | float rescale = 1.0f); | ||||
| @@ -54,7 +54,7 @@ public: | |||||
| const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
| output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | ||||
| size_t OH, size_t OW, size_t oc_start, size_t oc_end, | |||||
| size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||||
| size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | ||||
| const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
| float input_filter_scale = 1.0f, // input_scale * filter_scale | float input_filter_scale = 1.0f, // input_scale * filter_scale | ||||
| @@ -55,7 +55,7 @@ public: | |||||
| ohw_tile_size)); | ohw_tile_size)); | ||||
| all_algos.emplace_back(refhold.back().get()); | all_algos.emplace_back(refhold.back().get()); | ||||
| } | } | ||||
| #if 0 | |||||
| #if 1 | |||||
| //! As these algos may be very slow, it will make fastrun search slow, so | //! As these algos may be very slow, it will make fastrun search slow, so | ||||
| //! we disable it, but for the test of strategyhelper, we just keep it. | //! we disable it, but for the test of strategyhelper, we just keep it. | ||||
| //! FIXME: I do not know a better way to do it. | //! FIXME: I do not know a better way to do it. | ||||
| @@ -6,8 +6,7 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | */ | ||||
| #include "src/fallback/conv_bias/winograd/strategy.h" | #include "src/fallback/conv_bias/winograd/strategy.h" | ||||
| @@ -31,27 +30,54 @@ void winograd_2x3_1x1_f::filter(const float* filter, | |||||
| } | } | ||||
| void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, | void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, | ||||
| float* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, size_t IC, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| ::megdnn::winograd::StrategyHelper<float, float, float, float>::input( | |||||
| input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||||
| IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, src_dtype); | |||||
| float* transform_mid_buf, size_t IH, size_t IW, | |||||
| size_t IC, size_t PH, size_t PW, | |||||
| size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = | |||||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| rep(ic, IC) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| ::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||||
| input(input, input_transform_buf, transform_mid_buf, | |||||
| ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||||
| nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, src_dtype); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void winograd_2x3_1x1_f::output(const float* output_transform_buf, | void winograd_2x3_1x1_f::output(const float* output_transform_buf, | ||||
| const float* bias, float* output, | const float* bias, float* output, | ||||
| float* transform_mid_buf, BiasMode bmode, | float* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | |||||
| size_t ow_start, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, size_t unit_idx, | |||||
| NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
| ::megdnn::winograd::StrategyHelper<float, float, float, float>::output( | |||||
| output_transform_buf, bias, output, transform_mid_buf, bmode, | |||||
| nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
| unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, dst_dtype); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| ::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||||
| output(output_transform_buf, bias, output, | |||||
| transform_mid_buf, bmode, nonline_mode, oh_start, | |||||
| ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||||
| nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, dst_dtype); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | ||||
| @@ -71,38 +97,70 @@ void winograd_2x3_4x4_f::filter(const float* filter, | |||||
| } | } | ||||
| void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | ||||
| float* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, size_t IC, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| float, float, float, float, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK4>::input(input, input_transform_buf, | |||||
| transform_mid_buf, ih_start, | |||||
| iw_start, IH, IW, IC, | |||||
| unit_idx, nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, | |||||
| src_dtype); | |||||
| float* transform_mid_buf, size_t IH, size_t IW, | |||||
| size_t IC, size_t PH, size_t PW, | |||||
| size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = | |||||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| rep(ic, IC) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| float, float, float, float, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK4>::input(input, | |||||
| input_transform_buf, | |||||
| transform_mid_buf, | |||||
| ih_start, iw_start, | |||||
| IH, IW, IC, ic, | |||||
| unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, | |||||
| {0, 1, -1}, | |||||
| src_dtype); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void winograd_2x3_4x4_f::output(const float* output_transform_buf, | void winograd_2x3_4x4_f::output(const float* output_transform_buf, | ||||
| const float* bias, float* output, | const float* bias, float* output, | ||||
| float* transform_mid_buf, BiasMode bmode, | float* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | |||||
| size_t ow_start, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, size_t unit_idx, | |||||
| NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
| ::megdnn::winograd::StrategyHelper< | |||||
| float, float, float, float, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK4>::output(output_transform_buf, bias, | |||||
| output, transform_mid_buf, | |||||
| bmode, nonline_mode, | |||||
| oh_start, ow_start, OH, OW, | |||||
| oc_start, oc_end, unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, | |||||
| dst_dtype); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| float, float, float, float, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK4>::output(output_transform_buf, | |||||
| bias, output, | |||||
| transform_mid_buf, | |||||
| bmode, nonline_mode, | |||||
| oh_start, ow_start, | |||||
| OH, OW, OC, oc_start, | |||||
| oc_index, unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, | |||||
| {0, 1, -1}, | |||||
| dst_dtype); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) | ||||
| @@ -119,29 +177,59 @@ void winograd_2x3_1x1_qs8::filter(const int8_t* filter, | |||||
| void winograd_2x3_1x1_qs8::input(const int8_t* input, | void winograd_2x3_1x1_qs8::input(const int8_t* input, | ||||
| int16_t* input_transform_buf, | int16_t* input_transform_buf, | ||||
| int16_t* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, size_t IC, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| ::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::input( | |||||
| input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||||
| IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f); | |||||
| int16_t* transform_mid_buf, size_t IH, | |||||
| size_t IW, size_t IC, size_t PH, size_t PW, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = | |||||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| rep(ic, IC) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| ::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||||
| input(input, input_transform_buf, transform_mid_buf, | |||||
| ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||||
| nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, src_dtype, 1.0f); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, | void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, | ||||
| const int* bias, int8_t* output, | const int* bias, int8_t* output, | ||||
| int* transform_mid_buf, BiasMode bmode, | int* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | |||||
| size_t ow_start, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| NonlineMode nonline_mode, size_t OH, | |||||
| size_t OW, size_t oc_start, size_t oc_end, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | ||||
| float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; | float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; | ||||
| ::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::output( | |||||
| output_transform_buf, bias, output, transform_mid_buf, bmode, | |||||
| nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
| unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| ::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||||
| output(output_transform_buf, bias, output, | |||||
| transform_mid_buf, bmode, nonline_mode, oh_start, | |||||
| ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||||
| nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
| {0, 1, -1}, dst_dtype, scale_input * scale_filter, | |||||
| 2.0f, 1.0f); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) | ||||
| @@ -162,27 +250,44 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter, | |||||
| void winograd_2x3_8x8_qs8::input(const int8_t* input, | void winograd_2x3_8x8_qs8::input(const int8_t* input, | ||||
| int16_t* input_transform_buf, | int16_t* input_transform_buf, | ||||
| int16_t* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, size_t IC, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK8>::input(input, input_transform_buf, | |||||
| transform_mid_buf, ih_start, | |||||
| iw_start, IH, IW, IC, | |||||
| unit_idx, nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, | |||||
| src_dtype, 1.0f); | |||||
| int16_t* transform_mid_buf, size_t IH, | |||||
| size_t IW, size_t IC, size_t PH, size_t PW, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = | |||||
| div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| rep(ic, IC) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK8>::input(input, | |||||
| input_transform_buf, | |||||
| transform_mid_buf, | |||||
| ih_start, iw_start, | |||||
| IH, IW, IC, ic, | |||||
| unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, | |||||
| {0, 1, -1}, src_dtype, | |||||
| 1.0f); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | ||||
| const int* bias, int8_t* output, | const int* bias, int8_t* output, | ||||
| int* transform_mid_buf, BiasMode bmode, | int* transform_mid_buf, BiasMode bmode, | ||||
| NonlineMode nonline_mode, size_t oh_start, | |||||
| size_t ow_start, size_t OH, size_t OW, | |||||
| size_t oc_start, size_t oc_end, | |||||
| size_t unit_idx, size_t nr_units_in_tile) { | |||||
| NonlineMode nonline_mode, size_t OH, | |||||
| size_t OW, size_t oc_start, size_t oc_end, | |||||
| size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | |||||
| float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | ||||
| float scale_filter = 0.f; | float scale_filter = 0.f; | ||||
| if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { | if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { | ||||
| @@ -191,19 +296,37 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | |||||
| megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); | megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); | ||||
| scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; | scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; | ||||
| } | } | ||||
| ::megdnn::winograd::StrategyHelper< | |||||
| int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK8>::output(output_transform_buf, bias, | |||||
| output, transform_mid_buf, | |||||
| bmode, nonline_mode, | |||||
| oh_start, ow_start, OH, OW, | |||||
| oc_start, oc_end, unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, {0, 1, -1}, | |||||
| dst_dtype, | |||||
| scale_input * scale_filter, | |||||
| 2.0f, 1.0f); | |||||
| auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
| size_t OC = oc_end - oc_start; | |||||
| for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
| size_t oc_index = oc - oc_start; | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
| ::megdnn::winograd::StrategyHelper< | |||||
| int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
| param::MatrixMul::Format::MK8>::output(output_transform_buf, | |||||
| bias, output, | |||||
| transform_mid_buf, | |||||
| bmode, nonline_mode, | |||||
| oh_start, ow_start, | |||||
| OH, OW, OC, oc_start, | |||||
| oc_index, unit_idx, | |||||
| nr_units_in_tile, | |||||
| OUTPUT_BLOCK_SIZE, | |||||
| KERNEL_SIZE, | |||||
| {0, 1, -1}, | |||||
| dst_dtype, | |||||
| scale_input * | |||||
| scale_filter, | |||||
| 2.0f, 1.0f); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } // namespace winograd | } // namespace winograd | ||||
| @@ -321,17 +321,10 @@ public: | |||||
| "nr_tiles_in_unit: %zu TILE_SIZE:%zu", | "nr_tiles_in_unit: %zu TILE_SIZE:%zu", | ||||
| nr_tiles_in_unit, unit_tile_size); | nr_tiles_in_unit, unit_tile_size); | ||||
| } | } | ||||
| rep(unit_idx, nr_tiles_in_unit) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW; | |||||
| strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||||
| ih_start, iw_start, IH, IW, IC, unit_idx, | |||||
| nr_tiles_in_unit); | |||||
| } | |||||
| //! BTdB | |||||
| strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||||
| IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit); | |||||
| rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { | rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { | ||||
| if (format == param::MatrixMul::Format::DEFAULT) { | if (format == param::MatrixMul::Format::DEFAULT) { | ||||
| matmul_param.A_ptr = | matmul_param.A_ptr = | ||||
| @@ -368,22 +361,14 @@ public: | |||||
| } | } | ||||
| matmul_kern(matmul_param); | matmul_kern(matmul_param); | ||||
| } | } | ||||
| /* Y = ATmA */ | |||||
| rep(unit_idx, nr_tiles_in_unit) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| auto nh = index / units_w; | |||||
| auto nw = index % units_w; | |||||
| size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE; | |||||
| size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE; | |||||
| size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||||
| strategy.output( | |||||
| output_transform_buf, bias_ptr, dst_ptr, | |||||
| reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||||
| ncb_param.bias_mode, ncb_param.nonlineMode, oh_start, | |||||
| ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx, | |||||
| nr_tiles_in_unit); | |||||
| } | |||||
| //! Y = ATmA | |||||
| size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||||
| strategy.output( | |||||
| output_transform_buf, bias_ptr, dst_ptr, | |||||
| reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||||
| ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW, | |||||
| oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit); | |||||
| }; | }; | ||||
| SmallVector<NCBKern> get_kerns( | SmallVector<NCBKern> get_kerns( | ||||
| @@ -542,15 +527,16 @@ public: | |||||
| size_t IC, size_t oc_start, size_t oc_end); \ | size_t IC, size_t oc_start, size_t oc_end); \ | ||||
| void input(const stype* input, \ | void input(const stype* input, \ | ||||
| input_filter_compute_type* input_transform_buf, \ | input_filter_compute_type* input_transform_buf, \ | ||||
| input_filter_compute_type* transform_mid_buf, int ih_start, \ | |||||
| int iw_start, size_t IH, size_t IW, size_t IC, \ | |||||
| size_t unit_idx, size_t nr_tiles_in_unit); \ | |||||
| input_filter_compute_type* transform_mid_buf, \ | |||||
| size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \ | |||||
| size_t unit_start_idx, size_t nr_tiles_in_unit); \ | |||||
| void output(const output_compute_type* output_transform_buf, \ | void output(const output_compute_type* output_transform_buf, \ | ||||
| const output_compute_type* bias, dst_type* output, \ | const output_compute_type* bias, dst_type* output, \ | ||||
| output_compute_type* transform_mid_buf, BiasMode bmode, \ | output_compute_type* transform_mid_buf, BiasMode bmode, \ | ||||
| NonlineMode nonline_mode, size_t oh_start, \ | |||||
| size_t ow_start, size_t OH, size_t OW, size_t oc_start, \ | |||||
| size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \ | |||||
| NonlineMode nonline_mode, size_t OH, size_t OW, \ | |||||
| size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | |||||
| size_t nr_tiles_in_unit); \ | |||||
| }; | }; | ||||
| #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | ||||
| @@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter, | |||||
| transform_mid_buf, OC, IC, oc_start, | transform_mid_buf, OC, IC, oc_start, | ||||
| oc_end); | oc_end); | ||||
| } | } | ||||
| void winograd_nchw88_2x3_8x8_f::input(const float* input, | void winograd_nchw88_2x3_8x8_f::input(const float* input, | ||||
| float* input_transform_buf, | float* input_transform_buf, | ||||
| float* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, | |||||
| size_t IC, size_t unit_idx, | |||||
| float* transform_mid_buf, size_t IH, | |||||
| size_t IW, size_t IC, size_t PH, | |||||
| size_t PW, size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
| megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
| float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
| if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
| iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| InputTransform2X3_NCHW88::prepare<true>( | |||||
| input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||||
| InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, ic, | |||||
| IC); | |||||
| } | |||||
| } else { | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||||
| iw_start, IH, IW, ic, IC); | |||||
| InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, ic, | |||||
| IC); | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
| iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
| InputTransform2X3_NCHW88::prepare<true>(input, patch, patchT, | |||||
| ih_start, iw_start, IH, | |||||
| IW, ic, IC); | |||||
| InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, | |||||
| ic, IC); | |||||
| } else { | |||||
| InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, | |||||
| ih_start, iw_start, IH, | |||||
| IW, ic, IC); | |||||
| InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, | |||||
| ic, IC); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter, | |||||
| transform_mid_buf, OC, IC, oc_start, | transform_mid_buf, OC, IC, oc_start, | ||||
| oc_end); | oc_end); | ||||
| } | } | ||||
| void winograd_nchw88_6x3_8x8_f::input(const float* input, | void winograd_nchw88_6x3_8x8_f::input(const float* input, | ||||
| float* input_transform_buf, | float* input_transform_buf, | ||||
| float* transform_mid_buf, int ih_start, | |||||
| int iw_start, size_t IH, size_t IW, | |||||
| size_t IC, size_t unit_idx, | |||||
| float* transform_mid_buf, size_t IH, | |||||
| size_t IW, size_t IC, size_t PH, | |||||
| size_t PW, size_t unit_start_idx, | |||||
| size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
| megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
| // OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
| auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
| float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
| float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
| if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
| iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| InputTransform6X3_NCHW88::prepare<true>( | |||||
| input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||||
| InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, ic, | |||||
| IC); | |||||
| } | |||||
| } else { | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||||
| iw_start, IH, IW, ic, IC); | |||||
| InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, ic, | |||||
| IC); | |||||
| for (size_t ic = 0; ic < IC; ic += 8) { | |||||
| rep(unit_idx, nr_units_in_tile) { | |||||
| size_t index = unit_start_idx + unit_idx; | |||||
| size_t nh = index / units_w; | |||||
| size_t nw = index % units_w; | |||||
| int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
| int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
| if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
| iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
| InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT, | |||||
| ih_start, iw_start, IH, | |||||
| IW, ic, IC); | |||||
| InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, | |||||
| ic, IC); | |||||
| } else { | |||||
| InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, | |||||
| ih_start, iw_start, IH, | |||||
| IW, ic, IC); | |||||
| InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
| unit_idx, nr_units_in_tile, | |||||
| ic, IC); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||