GitOrigin-RevId: c91aee2c7c
tags/v0.3.2
| @@ -57,11 +57,12 @@ public: | |||
| const ConvBiasImpl::NCBKernParam& param, | |||
| const WorkspaceBundle& bundle_thread, size_t bundle_id, | |||
| size_t oc_cur_index, size_t OHW, bool is_dst_8bit, | |||
| bool ohw_bigger_ohwblock) { | |||
| bool ohw_bigger_ohwblock, size_t batch_id, size_t group_id) { | |||
| if (is_dst_8bit || !ohw_bigger_ohwblock) { | |||
| return static_cast<dtype*>(bundle_thread.get(bundle_id)); | |||
| } else { | |||
| dtype* dst = param.dst<dtype>() + oc_cur_index * OHW; | |||
| dtype* dst = | |||
| param.dst<dtype>(batch_id, group_id) + oc_cur_index * OHW; | |||
| return static_cast<dtype*>(dst); | |||
| } | |||
| } | |||
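For orientation, the direct-write branch above points the matmul output straight into the user tensor for the current oc tile of the (batch_id, group_id) slice; the element offset it adds is just oc_cur_index whole output planes. A quick numeric reading, with made-up sizes that are not from the patch:

    // illustrative sizes only (not taken from the patch)
    size_t OH = 56, OW = 56;
    size_t OHW = OH * OW;                      // 3136 output positions per channel
    size_t oc_cur_index = 16;                  // first channel of the current oc tile
    size_t elem_offset = oc_cur_index * OHW;   // = 50176 elements into the
                                               // (batch_id, group_id) dst slice

When the destination is 8-bit or the tile is smaller than a full OHW block, that offset is not used: the result goes to the thread bundle buffer instead and is copied into dst afterwards.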
| @@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| size_t IW2 = IW + 2 * PW; | |||
| size_t IH2 = IH + 2 * PH; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||
| size_t padding_group_size = IH2 * IW2 * IC; | |||
| size_t input_channel_offset = IH * IW * ncb_index.ndrange_id[2]; | |||
| size_t workspace_channel_offset = IH2 * IW2 * ncb_index.ndrange_id[2]; | |||
| size_t workspace_group_offset = | |||
| ncb_index.ndrange_id[0] * padding_group_size; | |||
| size_t workspace_batch_offset = param.filter_meta.group * | |||
| ncb_index.ndrange_id[1] * | |||
| padding_group_size; | |||
| size_t input_channel_offset = IH * IW * channel_id; | |||
| size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||
| size_t workspace_group_offset = group_id * padding_group_size; | |||
| size_t workspace_batch_offset = | |||
| param.filter_meta.group * batch_id * padding_group_size; | |||
| bundle.set(param.workspace_ptr); | |||
| src_ctype src_zp = static_cast<src_ctype>(0); | |||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
| } | |||
| src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>() + | |||
| input_channel_offset); | |||
| src_ctype* src = const_cast<src_ctype*>( | |||
| param.src<src_ctype>(batch_id, group_id) + input_channel_offset); | |||
| src_ctype* src2; | |||
| src2 = static_cast<src_ctype*>( | |||
| bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + | |||
| @@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| */ | |||
| #define COPY_BIAS() \ | |||
| const bias_ctype* bias_ptr = \ | |||
| static_cast<const bias_ctype*>(param.bias_ptr); \ | |||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( \ | |||
| param.bias<bias_ctype>(batch_id, group_id)); \ | |||
| bias_ctype* bias_temp_ptr = \ | |||
| PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \ | |||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { \ | |||
| @@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| #define IM2COL() \ | |||
| src_ctype* im2col_dst = nullptr; \ | |||
| src_ctype* no_padding_src = \ | |||
| const_cast<src_ctype*>(param.src<src_ctype>()) + ohw_cur_index; \ | |||
| const_cast<src_ctype*>(param.src<src_ctype>(batch_id, group_id)) + \ | |||
| ohw_cur_index; \ | |||
| if (!special_1x1) { \ | |||
| size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \ | |||
| src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ | |||
| @@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| param.filter_meta.group * ncb_index.ndrange_id[1]) * \ | |||
| padding_group_size); \ | |||
| if (PH == 0 && PW == 0) { \ | |||
| src2 = const_cast<src_ctype*>(param.src<src_ctype>()); \ | |||
| src2 = const_cast<src_ctype*>( \ | |||
| param.src<src_ctype>(batch_id, group_id)); \ | |||
| } \ | |||
| im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \ | |||
| Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \ | |||
| @@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| output_block_size); \ | |||
| if (!skip_copy_dst) { \ | |||
| dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \ | |||
| dst_ctype* dst = \ | |||
| param.dst<dst_ctype>() + oc_cur_index * OHW + ohw_cur_index; \ | |||
| dst_ctype* dst = param.dst<dst_ctype>(batch_id, group_id) + \ | |||
| oc_cur_index * OHW + ohw_cur_index; \ | |||
| for (size_t oc = 0; oc < output_block_oc_size; oc++) { \ | |||
| std::memcpy(dst, dst_tmp_ptr, \ | |||
| sizeof(dst_ctype) * output_block_size); \ | |||
| @@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
| bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ | |||
| param, bundle_thread, \ | |||
| Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \ | |||
| is_dst_8bit, is_ohw_size_bigger); | |||
| is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); | |||
| #define MATMUL_COMPUTE() \ | |||
| auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ | |||
| @@ -272,6 +276,7 @@ public: | |||
| ConvBiasImpl::NCBKernIndex ncb_index) { | |||
| bundle.set(param.workspace_ptr); | |||
| fallback::MatrixMulImpl::KernParam matmul_param; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | |||
| matmulparam; | |||
| size_t packA_group_size = | |||
| @@ -283,11 +288,11 @@ public: | |||
| matmul_algo->get_packA_type_size(); | |||
| size_t a_panel_offset = | |||
| ncb_index.ndrange_id[2] * packed_per_oc_block_size; | |||
| int8_t* a_panel = | |||
| static_cast<int8_t*>( | |||
| bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + | |||
| ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; | |||
| matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()); | |||
| int8_t* a_panel = static_cast<int8_t*>(bundle.get( | |||
| Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + | |||
| group_id * packA_group_size + a_panel_offset; | |||
| matmul_param.A_ptr = | |||
| const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | |||
| matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2], | |||
| matmul_algo->get_inner_block_size().m); | |||
| }; | |||
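As a concrete reading of the packed-A indexing above: each group owns packA_group_size bytes of the BUNDLE_PACKA workspace, and within a group each oc block owns packed_per_oc_block_size bytes. With made-up sizes (placeholders, not values from the patch):

    // placeholder sizes, purely illustrative
    size_t packA_group_size = 64 * 1024;          // packed-A bytes per group
    size_t packed_per_oc_block_size = 4 * 1024;   // packed-A bytes per oc block
    size_t group_id = 2, oc_block_id = 3;         // ndrange_id[0], ndrange_id[2]
    size_t a_panel_offset = oc_block_id * packed_per_oc_block_size;
    // panel start = 2 * 65536 + 3 * 4096 = 143360 bytes into BUNDLE_PACKA_INDEX
    size_t panel_start = group_id * packA_group_size + a_panel_offset;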
| @@ -309,6 +314,8 @@ public: | |||
| auto IH2 = IH + 2 * PH; | |||
| auto IW2 = IW + 2 * PW; | |||
| size_t OHW = OH * OW; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||
| size_t output_block_size = std::min( | |||
| ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); | |||
| size_t output_block_oc_size = std::min( | |||
| @@ -369,11 +376,11 @@ public: | |||
| \ | |||
| src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ | |||
| bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \ | |||
| ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset); \ | |||
| group_id * packA_group_size + a_panel_offset); \ | |||
| matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ | |||
| param, bundle_thread, \ | |||
| Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ | |||
| OHW, is_dst_8bit, is_ohw_size_bigger); | |||
| OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); | |||
| #define MATMUL_COMPUTE() \ | |||
| auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ | |||
| @@ -402,6 +409,7 @@ public: | |||
| matmulparam; | |||
| size_t OC = param.filter_meta.ocpg; | |||
| size_t oc_tile_size = matmul_param.M; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t output_block_oc_size = std::min( | |||
| oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size); | |||
| size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size; | |||
| @@ -411,12 +419,12 @@ public: | |||
| size_t a_panel_offset = | |||
| ncb_index.ndrange_id[2] * | |||
| matmul_algo->get_bundle(matmul_param).get_size(0); | |||
| int8_t* a_panel = | |||
| static_cast<int8_t*>( | |||
| bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + | |||
| ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; | |||
| matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()) + | |||
| oc_cur_index * matmul_param.K; | |||
| int8_t* a_panel = static_cast<int8_t*>(bundle.get( | |||
| Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + | |||
| group_id * packA_group_size + a_panel_offset; | |||
| matmul_param.A_ptr = | |||
| const_cast<src_ctype*>(param.filter<src_ctype>(group_id)) + | |||
| oc_cur_index * matmul_param.K; | |||
| matmul_param.M = output_block_oc_size; | |||
| matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); | |||
| }; | |||
| @@ -437,6 +445,8 @@ public: | |||
| MEGDNN_MARK_USED_VAR(N); | |||
| auto IH2 = IH + 2 * PH; | |||
| auto IW2 = IW + 2 * PW; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||
| size_t OHW = OH * OW; | |||
| size_t output_block_size = std::min( | |||
| ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); | |||
| @@ -490,11 +500,11 @@ public: | |||
| #define PREPAR_MATMUL_DATA() \ | |||
| bias_ctype* matmul_dst = nullptr; \ | |||
| const src_ctype* filter = \ | |||
| param.filter<src_ctype>() + oc_cur_index * IC * FH * FW; \ | |||
| param.filter<src_ctype>(group_id) + oc_cur_index * IC * FH * FW; \ | |||
| matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ | |||
| param, bundle_thread, \ | |||
| Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ | |||
| OHW, is_dst_8bit, is_ohw_size_bigger); | |||
| OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); | |||
| #define MATMUL_COMPUTE() \ | |||
| matmul_param.M = output_block_oc_size; \ | |||
| @@ -526,6 +536,8 @@ public: | |||
| MEGDNN_MARK_USED_VAR(N); | |||
| auto IH2 = IH + 2 * PH; | |||
| auto IW2 = IW + 2 * PW; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||
| size_t OHW = OH * OW; | |||
| size_t output_block_size = std::min( | |||
| ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); | |||
| @@ -245,65 +245,10 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param( | |||
| void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
| ConvBiasImpl::Algorithm* algo) { | |||
| auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); | |||
| size_t src_batch_stride = param.inp_bs * param.src_type.size(); | |||
| size_t dst_batch_stride = param.out_bs * param.dst_type.size(); | |||
| size_t bias_batch_stride = 0; | |||
| if (param.bias_mode == BiasMode::BIAS) { | |||
| bias_batch_stride = param.bias_bs * param.bias_type.size(); | |||
| } | |||
| for (auto&& kernel : ncb_kerns) { | |||
| megdnn_assert( | |||
| param.filter_meta.format == Param::Format::NCHW || | |||
| param.filter_meta.format == Param::Format::NHWC || | |||
| param.filter_meta.format == | |||
| Param::Format::NCHW_WINOGRAD || | |||
| param.filter_meta.format == Param::Format::NCHW88 || | |||
| param.filter_meta.format == | |||
| Param::Format::NCHW88_WINOGRAD, | |||
| "invalid conv format"); | |||
| ptrdiff_t istrd = 0, fstrd = 0, bstrd = 0, ostrd = 0; | |||
| if (param.filter_meta.format == Param::Format::NCHW_WINOGRAD || | |||
| param.filter_meta.format == Param::Format::NCHW88_WINOGRAD) { | |||
| fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * | |||
| (param.filter_meta.spatial[0] + param.output_block_size - | |||
| 1) * | |||
| (param.filter_meta.spatial[1] + param.output_block_size - | |||
| 1) * | |||
| param.filter_type.size(); | |||
| } else { | |||
| fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * | |||
| param.filter_meta.spatial[0] * | |||
| param.filter_meta.spatial[1] * param.filter_type.size(); | |||
| } | |||
| istrd = param.filter_meta.icpg * param.src_type.size(); | |||
| ostrd = param.filter_meta.ocpg * param.dst_type.size(); | |||
| if (param.bias_mode != BiasMode::NO_BIAS) { | |||
| bstrd = param.filter_meta.ocpg * param.bias_type.size(); | |||
| } | |||
| if (param.filter_meta.format == Param::Format::NCHW || | |||
| param.filter_meta.format == Param::Format::NCHW_WINOGRAD || | |||
| param.filter_meta.format == Param::Format::NCHW88_WINOGRAD) { | |||
| istrd *= param.isz[0] * param.isz[1]; | |||
| ostrd *= param.osz[0] * param.osz[1]; | |||
| if (param.bias_mode == BiasMode::BIAS) { | |||
| bstrd *= param.osz[0] * param.osz[1]; | |||
| } | |||
| } else { | |||
| // must be NHWC. No action performed. | |||
| } | |||
| auto run = [=](size_t index, size_t thread_id) { | |||
| auto copy_param = param; | |||
| CpuNDRange ndrange_id(kernel.global_size, index); | |||
| size_t group_id = ndrange_id[0]; | |||
| size_t batch_id = ndrange_id[1]; | |||
| //! The kernel ptr points to the batch index | |||
| incr_ptr(copy_param.src_ptr, | |||
| group_id * istrd + batch_id * src_batch_stride); | |||
| incr_ptr(copy_param.filter_ptr, group_id * fstrd); | |||
| incr_ptr(copy_param.bias_ptr, | |||
| group_id * bstrd + batch_id * bias_batch_stride); | |||
| incr_ptr(copy_param.dst_ptr, | |||
| group_id * ostrd + batch_id * dst_batch_stride); | |||
| kernel.kern(copy_param, {thread_id, ndrange_id}); | |||
| }; | |||
| static_cast<naive::HandleImpl*>(handle())->dispatch_kern( | |||
| @@ -381,4 +326,124 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { | |||
| return "F0"; | |||
| } | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| //! when the format is nchwxx and the convolution is channel-wise, multiple | |||
| //! groups are packed together, so group_pack_size is the number of groups | |||
| //! packed together | |||
| template <typename T> | |||
| const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size) const { | |||
| src_type.assert_is_compatible_ctype<T>(); | |||
| size_t batch_offset = batch_id * inp_bs * src_type.size(); | |||
| size_t group_offset = group_pack_size * group_id * filter_meta.icpg * | |||
| isz[0] * isz[1] * src_type.size(); | |||
| return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) + | |||
| batch_offset + group_offset); | |||
| } | |||
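The two offsets above are exactly the per-batch and per-group strides that exec_with_ncb_kern used to add to src_ptr before dispatching; a quick numeric check with made-up shapes (not from the patch):

    // made-up shapes, only to check the arithmetic: float32 src,
    // group = 4, icpg = 3, IH = IW = 32, group_pack_size = 1
    size_t inp_bs = 4 * 3 * 32 * 32;                            // 12288 elements per batch
    size_t batch_offset = 1 * inp_bs * sizeof(float);           // batch_id = 1 -> 49152 B
    size_t group_offset = 1 * 2 * 3 * 32 * 32 * sizeof(float);  // group_id = 2 -> 24576 B
    // src<float>(1, 2) therefore points 49152 + 24576 = 73728 bytes past src_ptr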
| //! when the format is nchwxx and the convolution is channel-wise, multiple | |||
| //! groups are packed together, so pack_group_size is the number of groups | |||
| //! packed together | |||
| template <typename T> | |||
| const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, | |||
| size_t pack_group_size) const { | |||
| size_t group_offset = 0_z; | |||
| switch (filter_meta.format) { | |||
| case Param::Format::NCHW: { | |||
| group_offset = pack_group_size * group_id * filter_meta.icpg * | |||
| filter_meta.ocpg * filter_meta.spatial[0] * | |||
| filter_meta.spatial[1] * filter_type.size(); | |||
| break; | |||
| } | |||
| case Param::Format::NCHW88: { | |||
| size_t group = filter_meta.group; | |||
| size_t icpg = filter_meta.icpg; | |||
| size_t ocpg = filter_meta.ocpg; | |||
| //! four formats of weight layout | |||
| //! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh, | |||
| //! fw, 8, 8} | |||
| //! 3. {g/8, 1, 1, fh, fw, 8, 8}, 4. {oc/8, fh, fw, ic, 8} | |||
| megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) || | |||
| (group % 8 == 0 && icpg == 1 && ocpg == 1 && | |||
| pack_group_size > 1) || | |||
| (group == 1 && ocpg % 8 == 0), | |||
| "The filter shepe is not right of nchw88"); | |||
| group_offset = pack_group_size * group_id * filter_meta.icpg * | |||
| filter_meta.ocpg * filter_meta.spatial[0] * | |||
| filter_meta.spatial[1] * filter_type.size(); | |||
| break; | |||
| } | |||
| case ConvBiasImpl::Param::Format::NCHW_WINOGRAD: | |||
| case ConvBiasImpl::Param::Format::NCHW88_WINOGRAD: { | |||
| //! four formats of weight layout | |||
| //! 1. {g, alpha, alpha, ocpg/8, icpg/8, 8, 8} | |||
| //! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8} | |||
| //! 3. {g, alpha, alpha, oc, ic, 8, 8} | |||
| //! 4. {alpha, alpha, oc, ic} | |||
| group_offset = pack_group_size * group_id * filter_meta.icpg * | |||
| filter_meta.ocpg * | |||
| (filter_meta.spatial[0] + output_block_size - 1) * | |||
| (filter_meta.spatial[1] + output_block_size - 1) * | |||
| filter_type.size(); | |||
| break; | |||
| } | |||
| default: | |||
| megdnn_assert("other filter format is not support yet"); | |||
| } | |||
| return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) + | |||
| group_offset); | |||
| } | |||
| //! when the format is nchwxx and the convolution is channel-wise, multiple | |||
| //! groups are packed together, so group_pack_size is the number of groups | |||
| //! packed together | |||
| template <typename T> | |||
| const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size) const { | |||
| bias_type.assert_is_compatible_ctype<T>(); | |||
| size_t batch_offset = 0_z; | |||
| size_t group_offset = 0_z; | |||
| if (bias_mode == BiasMode::BIAS) { | |||
| batch_offset = batch_id * bias_bs * bias_type.size(); | |||
| group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] * | |||
| osz[1] * bias_type.size(); | |||
| } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| group_offset = group_pack_size * group_id * filter_meta.ocpg * | |||
| bias_type.size(); | |||
| } | |||
| return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) + | |||
| batch_offset + group_offset); | |||
| } | |||
| //! when the format is nchwxx and the convolution is channel-wise, multiple | |||
| //! groups are packed together, so group_pack_size is the number of groups | |||
| //! packed together | |||
| template <typename T> | |||
| T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size) const { | |||
| dst_type.assert_is_compatible_ctype<T>(); | |||
| size_t batch_offset = batch_id * out_bs * dst_type.size(); | |||
| size_t group_offset = group_pack_size * group_id * filter_meta.ocpg * | |||
| osz[0] * osz[1] * dst_type.size(); | |||
| return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) + | |||
| batch_offset + group_offset); | |||
| } | |||
| #define INST(T) \ | |||
| template const T* ConvBiasImpl::NCBKernParam::src<T>( \ | |||
| size_t batch_id, size_t group_id, size_t group_pack_size) const; \ | |||
| template const T* ConvBiasImpl::NCBKernParam::bias<T>( \ | |||
| size_t batch_id, size_t group_id, size_t group_pack_size) const; \ | |||
| template const T* ConvBiasImpl::NCBKernParam::filter<T>( \ | |||
| size_t group_id, size_t group_pack_size) const; \ | |||
| template T* ConvBiasImpl::NCBKernParam::dst<T>( \ | |||
| size_t batch_id, size_t group_id, size_t group_pack_size) const; | |||
| #define INST_DT(d) INST(DTypeTrait<d>::ctype) | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) | |||
| #undef INST | |||
| #undef INST_DT | |||
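For readability, here is what one pass of the foreach expands to, taking the float32 ctype (float) as the example; this is just the INST macro above written out, not extra code from the patch:

    template const float* ConvBiasImpl::NCBKernParam::src<float>(
            size_t batch_id, size_t group_id, size_t group_pack_size) const;
    template const float* ConvBiasImpl::NCBKernParam::bias<float>(
            size_t batch_id, size_t group_id, size_t group_pack_size) const;
    template const float* ConvBiasImpl::NCBKernParam::filter<float>(
            size_t group_id, size_t group_pack_size) const;
    template float* ConvBiasImpl::NCBKernParam::dst<float>(
            size_t batch_id, size_t group_id, size_t group_pack_size) const;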
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
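Taken together, kernels no longer rely on exec_with_ncb_kern pre-offsetting the src/filter/bias/dst pointers per batch and group; each kernel derives its coordinates from ncb_index and asks the param for the right slice. A minimal sketch of that pattern (the lambda body and the float dtype are illustrative, not code from the patch):

    // sketch: assumes the fallback ConvBiasImpl::NCBKernParam / NCBKernIndex
    // types declared below; the body is hypothetical
    auto kern = [](const ConvBiasImpl::NCBKernParam& param,
                   const ConvBiasImpl::NCBKernIndex& ncb_index) {
        size_t group_id = ncb_index.ndrange_id[0];
        size_t batch_id = ncb_index.ndrange_id[1];
        // per-(batch, group) views; the offsets are handled by the new accessors
        const float* src  = param.src<float>(batch_id, group_id);
        const float* flt  = param.filter<float>(group_id);
        const float* bias = param.bias<float>(batch_id, group_id);
        float* dst        = param.dst<float>(batch_id, group_id);
        // ... compute one (batch, group) slice here ...
    };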
| @@ -104,24 +104,39 @@ public: | |||
| return static_cast<const T*>(src_ptr); | |||
| } | |||
| template <typename T> | |||
| const T* src(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size = 1_z) const; | |||
| template <typename T> | |||
| const T* filter() const { | |||
| filter_type.assert_is_compatible_ctype<T>(); | |||
| return static_cast<const T*>(filter_ptr); | |||
| } | |||
| template <typename T> | |||
| const T* filter(size_t group_id, size_t pack_group_size = 1_z) const; | |||
| template <typename T> | |||
| const T* bias() const { | |||
| bias_type.assert_is_compatible_ctype<T>(); | |||
| return static_cast<const T*>(bias_ptr); | |||
| } | |||
| template <typename T> | |||
| const T* bias(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size = 1_z) const; | |||
| template <typename T> | |||
| T* dst() const { | |||
| dst_type.assert_is_compatible_ctype<T>(); | |||
| return static_cast<T*>(dst_ptr); | |||
| } | |||
| template <typename T> | |||
| T* dst(size_t batch_id, size_t group_id, | |||
| size_t group_pack_size = 1_z) const; | |||
| template <typename T> | |||
| T* workspace() const { | |||
| return static_cast<T*>(workspace_ptr); | |||
| @@ -210,7 +210,7 @@ public: | |||
| reinterpret_cast<input_filter_compute_type*>( | |||
| reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + | |||
| compute_workspace_size_per_thread * thread_id); | |||
| const stype* filter_ptr = kern_param.filter<stype>(); | |||
| const stype* filter_ptr = kern_param.filter<stype>(group_id); | |||
| size_t oc_start = oc_id, oc_end = oc_id+1; | |||
| if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||
| oc_start = 8 * oc_id; | |||
| @@ -246,16 +246,19 @@ public: | |||
| size_t oc_block_id = ncb_index.ndrange_id[3]; | |||
| size_t tile_id = ncb_index.ndrange_id[2]; | |||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| size_t thread_id = ncb_index.thread_id; | |||
| bundle_top.set(ncb_param.workspace_ptr); | |||
| bundle_compute.set(bundle_top.get(0)); | |||
| const stype* src_ptr = ncb_param.src<stype>(); | |||
| dst_type* dst_ptr = ncb_param.dst<dst_type>(); | |||
| const stype* src_ptr = ncb_param.src<stype>(batch_id, group_id); | |||
| dst_type* dst_ptr = ncb_param.dst<dst_type>(batch_id, group_id); | |||
| const output_compute_type* bias_ptr = | |||
| static_cast<const output_compute_type*>(ncb_param.bias_ptr); | |||
| static_cast<const output_compute_type*>( | |||
| ncb_param.bias<output_compute_type>(batch_id, | |||
| group_id)); | |||
| input_filter_compute_type* input_transform_buf = | |||
| reinterpret_cast<input_filter_compute_type*>( | |||
| @@ -271,9 +274,10 @@ public: | |||
| reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + | |||
| compute_workspace_size_per_thread * thread_id); | |||
| //! NCHW88_WINOGRAD and NCHW_WINOGRAD use the same offset | |||
| const input_filter_compute_type* filter_transform_buf = | |||
| static_cast<const input_filter_compute_type*>( | |||
| ncb_param.filter_ptr); | |||
| ncb_param.filter<input_filter_compute_type>(group_id)); | |||
| if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || | |||
| ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||
| filter_transform_buf = reinterpret_cast<input_filter_compute_type*>( | |||