From d346c8783643d29ef79641ec09baf6b146ec6edf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 20 Mar 2020 13:42:07 +0800 Subject: [PATCH] fix(dnn/fallbackls): delete the conv_bias fallback offset GitOrigin-RevId: c91aee2c7cfc95d1f31cc7f7eb7a05ece40ba002 --- dnn/src/fallback/conv_bias/im2col/algos.cpp | 78 ++++---- dnn/src/fallback/conv_bias/opr_impl.cpp | 175 ++++++++++++------ dnn/src/fallback/conv_bias/opr_impl.h | 15 ++ .../fallback/conv_bias/winograd/winograd.h | 14 +- 4 files changed, 189 insertions(+), 93 deletions(-) diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 368a8e2a..b2712c8c 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -57,11 +57,12 @@ public: const ConvBiasImpl::NCBKernParam& param, const WorkspaceBundle& bundle_thread, size_t bundle_id, size_t oc_cur_index, size_t OHW, bool is_dst_8bit, - bool ohw_bigger_ohwblock) { + bool ohw_bigger_ohwblock, size_t batch_id, size_t group_id) { if (is_dst_8bit || !ohw_bigger_ohwblock) { return static_cast(bundle_thread.get(bundle_id)); } else { - dtype* dst = param.dst() + oc_cur_index * OHW; + dtype* dst = + param.dst(batch_id, group_id) + oc_cur_index * OHW; return static_cast(dst); } } @@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle, size_t IW2 = IW + 2 * PW; size_t IH2 = IH + 2 * PH; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; + size_t channel_id = ncb_index.ndrange_id[2]; size_t padding_group_size = IH2 * IW2 * IC; - size_t input_channel_offset = IH * IW * ncb_index.ndrange_id[2]; - size_t workspace_channel_offset = IH2 * IW2 * ncb_index.ndrange_id[2]; - size_t workspace_group_offset = - ncb_index.ndrange_id[0] * padding_group_size; - size_t workspace_batch_offset = param.filter_meta.group * - ncb_index.ndrange_id[1] * - padding_group_size; + size_t input_channel_offset = IH * IW * channel_id; + 
size_t workspace_channel_offset = IH2 * IW2 * channel_id; + size_t workspace_group_offset = group_id * padding_group_size; + size_t workspace_batch_offset = + param.filter_meta.group * batch_id * padding_group_size; bundle.set(param.workspace_ptr); src_ctype src_zp = static_cast(0); if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { src_zp = param.src_type.param().zero_point; } - src_ctype* src = const_cast(param.src() + - input_channel_offset); + src_ctype* src = const_cast( + param.src(batch_id, group_id) + input_channel_offset); src_ctype* src2; src2 = static_cast( bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + @@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, */ #define COPY_BIAS() \ - const bias_ctype* bias_ptr = \ - static_cast(param.bias_ptr); \ + const bias_ctype* bias_ptr = static_cast( \ + param.bias(batch_id, group_id)); \ bias_ctype* bias_temp_ptr = \ PtrGetter::get_bias_temp_ptr(param, bundle_thread); \ if (param.bias_mode == megdnn::BiasMode::BIAS) { \ @@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, #define IM2COL() \ src_ctype* im2col_dst = nullptr; \ src_ctype* no_padding_src = \ - const_cast(param.src()) + ohw_cur_index; \ + const_cast(param.src(batch_id, group_id)) + \ + ohw_cur_index; \ if (!special_1x1) { \ size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \ src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr( \ @@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, param.filter_meta.group * ncb_index.ndrange_id[1]) * \ padding_group_size); \ if (PH == 0 && PW == 0) { \ - src2 = const_cast(param.src()); \ + src2 = const_cast( \ + param.src(batch_id, group_id)); \ } \ im2col_dst = static_cast(bundle_thread.get( \ Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \ @@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, output_block_size); \ if (!skip_copy_dst) { \ dst_ctype* dst_tmp_ptr = reinterpret_cast(matmul_dst); \ - 
dst_ctype* dst = \ - param.dst() + oc_cur_index * OHW + ohw_cur_index; \ + dst_ctype* dst = param.dst(batch_id, group_id) + \ + oc_cur_index * OHW + ohw_cur_index; \ for (size_t oc = 0; oc < output_block_oc_size; oc++) { \ std::memcpy(dst, dst_tmp_ptr, \ sizeof(dst_ctype) * output_block_size); \ @@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr( \ param, bundle_thread, \ Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \ - is_dst_8bit, is_ohw_size_bigger); + is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); #define MATMUL_COMPUTE() \ auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ @@ -272,6 +276,7 @@ public: ConvBiasImpl::NCBKernIndex ncb_index) { bundle.set(param.workspace_ptr); fallback::MatrixMulImpl::KernParam matmul_param; + size_t group_id = ncb_index.ndrange_id[0]; static_cast(matmul_param) = matmulparam; size_t packA_group_size = @@ -283,11 +288,11 @@ public: matmul_algo->get_packA_type_size(); size_t a_panel_offset = ncb_index.ndrange_id[2] * packed_per_oc_block_size; - int8_t* a_panel = - static_cast( - bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + - ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; - matmul_param.A_ptr = const_cast(param.filter()); + int8_t* a_panel = static_cast(bundle.get( + Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + + group_id * packA_group_size + a_panel_offset; + matmul_param.A_ptr = + const_cast(param.filter(group_id)); matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2], matmul_algo->get_inner_block_size().m); }; @@ -309,6 +314,8 @@ public: auto IH2 = IH + 2 * PH; auto IW2 = IW + 2 * PW; size_t OHW = OH * OW; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; size_t output_block_size = std::min( ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); size_t output_block_oc_size = std::min( @@ -369,11 +376,11 @@ public: \ src_ctype* 
a_panel = PtrGetter::get_bundle_offset_byte_ptr( \ bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \ - ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset); \ + group_id * packA_group_size + a_panel_offset); \ matmul_dst = PtrGetter::get_matmul_dst_ptr( \ param, bundle_thread, \ Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ - OHW, is_dst_8bit, is_ohw_size_bigger); + OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); #define MATMUL_COMPUTE() \ auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ @@ -402,6 +409,7 @@ public: matmulparam; size_t OC = param.filter_meta.ocpg; size_t oc_tile_size = matmul_param.M; + size_t group_id = ncb_index.ndrange_id[0]; size_t output_block_oc_size = std::min( oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size); size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size; @@ -411,12 +419,12 @@ public: size_t a_panel_offset = ncb_index.ndrange_id[2] * matmul_algo->get_bundle(matmul_param).get_size(0); - int8_t* a_panel = - static_cast( - bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + - ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; - matmul_param.A_ptr = const_cast(param.filter()) + - oc_cur_index * matmul_param.K; + int8_t* a_panel = static_cast(bundle.get( + Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + + group_id * packA_group_size + a_panel_offset; + matmul_param.A_ptr = + const_cast(param.filter(group_id)) + + oc_cur_index * matmul_param.K; matmul_param.M = output_block_oc_size; matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); }; @@ -437,6 +445,8 @@ public: MEGDNN_MARK_USED_VAR(N); auto IH2 = IH + 2 * PH; auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; size_t OHW = OH * OW; size_t output_block_size = std::min( ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); @@ -490,11 +500,11 @@ public: #define PREPAR_MATMUL_DATA() \ bias_ctype* matmul_dst = nullptr; \ const src_ctype* 
filter = \ - param.filter() + oc_cur_index * IC * FH * FW; \ + param.filter(group_id) + oc_cur_index * IC * FH * FW; \ matmul_dst = PtrGetter::get_matmul_dst_ptr( \ param, bundle_thread, \ Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ - OHW, is_dst_8bit, is_ohw_size_bigger); + OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); #define MATMUL_COMPUTE() \ matmul_param.M = output_block_oc_size; \ @@ -526,6 +536,8 @@ public: MEGDNN_MARK_USED_VAR(N); auto IH2 = IH + 2 * PH; auto IW2 = IW + 2 * PW; + size_t group_id = ncb_index.ndrange_id[0]; + size_t batch_id = ncb_index.ndrange_id[1]; size_t OHW = OH * OW; size_t output_block_size = std::min( ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 8f431f46..7985770b 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param( void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) { auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); - size_t src_batch_stride = param.inp_bs * param.src_type.size(); - size_t dst_batch_stride = param.out_bs * param.dst_type.size(); - size_t bias_batch_stride = 0; - if (param.bias_mode == BiasMode::BIAS) { - bias_batch_stride = param.bias_bs * param.bias_type.size(); - } for (auto&& kernel : ncb_kerns) { - megdnn_assert( - param.filter_meta.format == Param::Format::NCHW || - param.filter_meta.format == Param::Format::NHWC || - param.filter_meta.format == - Param::Format::NCHW_WINOGRAD || - param.filter_meta.format == Param::Format::NCHW88 || - param.filter_meta.format == - Param::Format::NCHW88_WINOGRAD, - "invalid conv format"); - ptrdiff_t istrd = 0, fstrd = 0, bstrd = 0, ostrd = 0; - if (param.filter_meta.format == Param::Format::NCHW_WINOGRAD || - param.filter_meta.format == 
Param::Format::NCHW88_WINOGRAD) { - fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * - (param.filter_meta.spatial[0] + param.output_block_size - - 1) * - (param.filter_meta.spatial[1] + param.output_block_size - - 1) * - param.filter_type.size(); - } else { - fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * - param.filter_meta.spatial[0] * - param.filter_meta.spatial[1] * param.filter_type.size(); - } - istrd = param.filter_meta.icpg * param.src_type.size(); - ostrd = param.filter_meta.ocpg * param.dst_type.size(); - if (param.bias_mode != BiasMode::NO_BIAS) { - bstrd = param.filter_meta.ocpg * param.bias_type.size(); - } - if (param.filter_meta.format == Param::Format::NCHW || - param.filter_meta.format == Param::Format::NCHW_WINOGRAD || - param.filter_meta.format == Param::Format::NCHW88_WINOGRAD) { - istrd *= param.isz[0] * param.isz[1]; - ostrd *= param.osz[0] * param.osz[1]; - if (param.bias_mode == BiasMode::BIAS) { - bstrd *= param.osz[0] * param.osz[1]; - } - } else { - // must be NHWC. No action performed. - } auto run = [=](size_t index, size_t thread_id) { auto copy_param = param; CpuNDRange ndrange_id(kernel.global_size, index); - size_t group_id = ndrange_id[0]; - size_t batch_id = ndrange_id[1]; - //! The kernel ptr point to batch index - incr_ptr(copy_param.src_ptr, - group_id * istrd + batch_id * src_batch_stride); - incr_ptr(copy_param.filter_ptr, group_id * fstrd); - incr_ptr(copy_param.bias_ptr, - group_id * bstrd + batch_id * bias_batch_stride); - incr_ptr(copy_param.dst_ptr, - group_id * ostrd + batch_id * dst_batch_stride); kernel.kern(copy_param, {thread_id, ndrange_id}); }; static_cast(handle())->dispatch_kern( @@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { return "F0"; } +namespace megdnn{ +namespace fallback { +//! when format is nchwxx and channel wise mode, multi group will pack + //! 
together, so group_pack_size is the number of packed groups
+template <typename T>
+const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id,
+                                         size_t group_pack_size) const {
+    src_type.assert_is_compatible_ctype<T>();
+    size_t batch_offset = batch_id * inp_bs * src_type.size();  //!< in bytes
+    size_t group_offset = group_pack_size * group_id * filter_meta.icpg *
+                          isz[0] * isz[1] * src_type.size();  //!< in bytes
+    return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
+                                batch_offset + group_offset);
+}
+
+//! Filter pointer for \p group_id (byte offset from filter_ptr). In nchwxx
+//! channel wise mode groups are packed; pack_group_size is how many per pack
+template <typename T>
+const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
+                                            size_t pack_group_size) const {
+    size_t group_offset = 0_z;  //!< byte offset of the requested group
+    switch (filter_meta.format) {
+        case Param::Format::NCHW: {
+            group_offset = pack_group_size * group_id * filter_meta.icpg *
+                           filter_meta.ocpg * filter_meta.spatial[0] *
+                           filter_meta.spatial[1] * filter_type.size();
+            break;
+        }
+        case Param::Format::NCHW88: {
+            size_t group = filter_meta.group;
+            size_t icpg = filter_meta.icpg;
+            size_t ocpg = filter_meta.ocpg;
+            //! four formats of weight layout
+            //! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh,
+            //! fw, 8, 8}
+            //! 3. {g/8, 1, 1, fh, fw, 8, 8}, 4. {oc/8, fh, fw, ic, 8}
+            megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) ||
+                                  (group % 8 == 0 && icpg == 1 && ocpg == 1 &&
+                                   pack_group_size > 1) ||
+                                  (group == 1 && ocpg % 8 == 0),
+                          "The filter shepe is not right of nchw88");
+            group_offset = pack_group_size * group_id * filter_meta.icpg *
+                           filter_meta.ocpg * filter_meta.spatial[0] *
+                           filter_meta.spatial[1] * filter_type.size();
+
+            break;
+        }
+        case ConvBiasImpl::Param::Format::NCHW_WINOGRAD:
+        case ConvBiasImpl::Param::Format::NCHW88_WINOGRAD: {
+            //! four formats of weight layout; the per-group size uses the
+            //! transformed spatial extent (spatial + output_block_size - 1)
+            //! 1. {g, alpha, alpha, ocpg/8, icpg/8, 8, 8}
+            //! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8}
+            //! 3. {g, alpha, alpha, oc, ic, 8, 8} 4. {alpha, alpha, oc, ic}
+            group_offset = pack_group_size * group_id * filter_meta.icpg *
+                           filter_meta.ocpg *
+                           (filter_meta.spatial[0] + output_block_size - 1) *
+                           (filter_meta.spatial[1] + output_block_size - 1) *
+                           filter_type.size();
+            break;
+        }
+        default:  //! a bare string literal is always true; pass 0 to fail
+            megdnn_assert(0, "other filter format is not support yet");
+    }
+    return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) +
+                                group_offset);
+}
+
+//! Bias pointer for (\p batch_id, \p group_id); only BiasMode::BIAS has a batch
+//! offset. group_pack_size is the number of groups packed together (nchwxx)
+template <typename T>
+const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id,
+                                          size_t group_pack_size) const {
+    bias_type.assert_is_compatible_ctype<T>();
+    size_t batch_offset = 0_z;  //!< stays 0 unless bias_mode is BIAS
+    size_t group_offset = 0_z;  //!< stays 0 for any other bias_mode
+    if (bias_mode == BiasMode::BIAS) {
+        batch_offset = batch_id * bias_bs * bias_type.size();
+        group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] *
+                       osz[1] * bias_type.size();
+    } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        group_offset = group_pack_size * group_id * filter_meta.ocpg *
+                       bias_type.size();
+    }
+    return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) +
+                                batch_offset + group_offset);
+}
+
+//! when format is nchwxx and channel wise mode, multi group will pack
+//! 
together, so pack_group_size is the number of packed group +template +T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id, + size_t group_pack_size) const { + dst_type.assert_is_compatible_ctype(); + size_t batch_offset = batch_id * out_bs * dst_type.size(); + size_t group_offset = group_pack_size * group_id * filter_meta.ocpg * + osz[0] * osz[1] * dst_type.size(); + return reinterpret_cast(reinterpret_cast(dst_ptr) + + batch_offset + group_offset); +} + +#define INST(T) \ + template const T* ConvBiasImpl::NCBKernParam::src( \ + size_t batch_id, size_t group_id, size_t group_pack_size) const; \ + template const T* ConvBiasImpl::NCBKernParam::bias( \ + size_t batch_id, size_t group_id, size_t group_pack_size) const; \ + template const T* ConvBiasImpl::NCBKernParam::filter( \ + size_t group_id, size_t group_pack_size) const; \ + template T* ConvBiasImpl::NCBKernParam::dst( \ + size_t batch_id, size_t group_id, size_t group_pack_size) const; + +#define INST_DT(d) INST(DTypeTrait::ctype) + +MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) +#undef INST +#undef INST_DT +} // namespace fallback +} // namespace megdnn + + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 6fef6135..1b121382 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -104,24 +104,39 @@ public: return static_cast(src_ptr); } + template + const T* src(size_t batch_id, size_t group_id, + size_t group_pack_size = 1_z) const; + template const T* filter() const { filter_type.assert_is_compatible_ctype(); return static_cast(filter_ptr); } + template + const T* filter(size_t group_id, size_t pack_group_size = 1_z) const; + template const T* bias() const { bias_type.assert_is_compatible_ctype(); return static_cast(bias_ptr); } + template + const T* bias(size_t batch_id, size_t group_id, + size_t group_pack_size = 1_z) const; + template T* dst() const { 
dst_type.assert_is_compatible_ctype(); return static_cast(dst_ptr); } + template + T* dst(size_t batch_id, size_t group_id, + size_t group_pack_size = 1_z) const; + template T* workspace() const { return static_cast(workspace_ptr); diff --git a/dnn/src/fallback/conv_bias/winograd/winograd.h b/dnn/src/fallback/conv_bias/winograd/winograd.h index e6887f50..8e317d3e 100644 --- a/dnn/src/fallback/conv_bias/winograd/winograd.h +++ b/dnn/src/fallback/conv_bias/winograd/winograd.h @@ -210,7 +210,7 @@ public: reinterpret_cast( reinterpret_cast(bundle_compute.get(2)) + compute_workspace_size_per_thread * thread_id); - const stype* filter_ptr = kern_param.filter(); + const stype* filter_ptr = kern_param.filter(group_id); size_t oc_start = oc_id, oc_end = oc_id+1; if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { oc_start = 8 * oc_id; @@ -246,16 +246,19 @@ public: size_t oc_block_id = ncb_index.ndrange_id[3]; size_t tile_id = ncb_index.ndrange_id[2]; + size_t batch_id = ncb_index.ndrange_id[1]; size_t group_id = ncb_index.ndrange_id[0]; size_t thread_id = ncb_index.thread_id; bundle_top.set(ncb_param.workspace_ptr); bundle_compute.set(bundle_top.get(0)); - const stype* src_ptr = ncb_param.src(); - dst_type* dst_ptr = ncb_param.dst(); + const stype* src_ptr = ncb_param.src(batch_id, group_id); + dst_type* dst_ptr = ncb_param.dst(batch_id, group_id); const output_compute_type* bias_ptr = - static_cast(ncb_param.bias_ptr); + static_cast( + ncb_param.bias(batch_id, + group_id)); input_filter_compute_type* input_transform_buf = reinterpret_cast( @@ -271,9 +274,10 @@ public: reinterpret_cast(bundle_compute.get(2)) + compute_workspace_size_per_thread * thread_id); + //! 
NCHW88_WINOGRAD and NCHW_WINOGRAD is the same offset const input_filter_compute_type* filter_transform_buf = static_cast( - ncb_param.filter_ptr); + ncb_param.filter(group_id)); if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) { filter_transform_buf = reinterpret_cast(