|
|
@@ -57,11 +57,12 @@ public: |
|
|
const ConvBiasImpl::NCBKernParam& param, |
|
|
const ConvBiasImpl::NCBKernParam& param, |
|
|
const WorkspaceBundle& bundle_thread, size_t bundle_id, |
|
|
const WorkspaceBundle& bundle_thread, size_t bundle_id, |
|
|
size_t oc_cur_index, size_t OHW, bool is_dst_8bit, |
|
|
size_t oc_cur_index, size_t OHW, bool is_dst_8bit, |
|
|
bool ohw_bigger_ohwblock) { |
|
|
|
|
|
|
|
|
bool ohw_bigger_ohwblock, size_t batch_id, size_t group_id) { |
|
|
if (is_dst_8bit || !ohw_bigger_ohwblock) { |
|
|
if (is_dst_8bit || !ohw_bigger_ohwblock) { |
|
|
return static_cast<dtype*>(bundle_thread.get(bundle_id)); |
|
|
return static_cast<dtype*>(bundle_thread.get(bundle_id)); |
|
|
} else { |
|
|
} else { |
|
|
dtype* dst = param.dst<dtype>() + oc_cur_index * OHW; |
|
|
|
|
|
|
|
|
dtype* dst = |
|
|
|
|
|
param.dst<dtype>(batch_id, group_id) + oc_cur_index * OHW; |
|
|
return static_cast<dtype*>(dst); |
|
|
return static_cast<dtype*>(dst); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -105,23 +106,24 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
|
|
|
|
|
|
size_t IW2 = IW + 2 * PW; |
|
|
size_t IW2 = IW + 2 * PW; |
|
|
size_t IH2 = IH + 2 * PH; |
|
|
size_t IH2 = IH + 2 * PH; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
|
|
|
size_t batch_id = ncb_index.ndrange_id[1]; |
|
|
|
|
|
size_t channel_id = ncb_index.ndrange_id[2]; |
|
|
|
|
|
|
|
|
size_t padding_group_size = IH2 * IW2 * IC; |
|
|
size_t padding_group_size = IH2 * IW2 * IC; |
|
|
size_t input_channel_offset = IH * IW * ncb_index.ndrange_id[2]; |
|
|
|
|
|
size_t workspace_channel_offset = IH2 * IW2 * ncb_index.ndrange_id[2]; |
|
|
|
|
|
size_t workspace_group_offset = |
|
|
|
|
|
ncb_index.ndrange_id[0] * padding_group_size; |
|
|
|
|
|
size_t workspace_batch_offset = param.filter_meta.group * |
|
|
|
|
|
ncb_index.ndrange_id[1] * |
|
|
|
|
|
padding_group_size; |
|
|
|
|
|
|
|
|
size_t input_channel_offset = IH * IW * channel_id; |
|
|
|
|
|
size_t workspace_channel_offset = IH2 * IW2 * channel_id; |
|
|
|
|
|
size_t workspace_group_offset = group_id * padding_group_size; |
|
|
|
|
|
size_t workspace_batch_offset = |
|
|
|
|
|
param.filter_meta.group * batch_id * padding_group_size; |
|
|
bundle.set(param.workspace_ptr); |
|
|
bundle.set(param.workspace_ptr); |
|
|
|
|
|
|
|
|
src_ctype src_zp = static_cast<src_ctype>(0); |
|
|
src_ctype src_zp = static_cast<src_ctype>(0); |
|
|
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { |
|
|
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { |
|
|
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; |
|
|
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; |
|
|
} |
|
|
} |
|
|
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>() + |
|
|
|
|
|
input_channel_offset); |
|
|
|
|
|
|
|
|
src_ctype* src = const_cast<src_ctype*>( |
|
|
|
|
|
param.src<src_ctype>(batch_id, group_id) + input_channel_offset); |
|
|
src_ctype* src2; |
|
|
src_ctype* src2; |
|
|
src2 = static_cast<src_ctype*>( |
|
|
src2 = static_cast<src_ctype*>( |
|
|
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + |
|
|
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + |
|
|
@@ -153,8 +155,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
#define COPY_BIAS() \ |
|
|
#define COPY_BIAS() \ |
|
|
const bias_ctype* bias_ptr = \ |
|
|
|
|
|
static_cast<const bias_ctype*>(param.bias_ptr); \ |
|
|
|
|
|
|
|
|
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( \ |
|
|
|
|
|
param.bias<bias_ctype>(batch_id, group_id)); \ |
|
|
bias_ctype* bias_temp_ptr = \ |
|
|
bias_ctype* bias_temp_ptr = \ |
|
|
PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \ |
|
|
PtrGetter::get_bias_temp_ptr<bias_ctype>(param, bundle_thread); \ |
|
|
if (param.bias_mode == megdnn::BiasMode::BIAS) { \ |
|
|
if (param.bias_mode == megdnn::BiasMode::BIAS) { \ |
|
|
@@ -172,7 +174,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
#define IM2COL() \ |
|
|
#define IM2COL() \ |
|
|
src_ctype* im2col_dst = nullptr; \ |
|
|
src_ctype* im2col_dst = nullptr; \ |
|
|
src_ctype* no_padding_src = \ |
|
|
src_ctype* no_padding_src = \ |
|
|
const_cast<src_ctype*>(param.src<src_ctype>()) + ohw_cur_index; \ |
|
|
|
|
|
|
|
|
const_cast<src_ctype*>(param.src<src_ctype>(batch_id, group_id)) + \ |
|
|
|
|
|
ohw_cur_index; \ |
|
|
if (!special_1x1) { \ |
|
|
if (!special_1x1) { \ |
|
|
size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \ |
|
|
size_t padding_group_size = IH2 * IW2 * IC * sizeof(src_ctype); \ |
|
|
src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ |
|
|
src_ctype* src2 = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ |
|
|
@@ -181,7 +184,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
param.filter_meta.group * ncb_index.ndrange_id[1]) * \ |
|
|
param.filter_meta.group * ncb_index.ndrange_id[1]) * \ |
|
|
padding_group_size); \ |
|
|
padding_group_size); \ |
|
|
if (PH == 0 && PW == 0) { \ |
|
|
if (PH == 0 && PW == 0) { \ |
|
|
src2 = const_cast<src_ctype*>(param.src<src_ctype>()); \ |
|
|
|
|
|
|
|
|
src2 = const_cast<src_ctype*>( \ |
|
|
|
|
|
param.src<src_ctype>(batch_id, group_id)); \ |
|
|
} \ |
|
|
} \ |
|
|
im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \ |
|
|
im2col_dst = static_cast<src_ctype*>(bundle_thread.get( \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX)); \ |
|
|
@@ -217,8 +221,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
output_block_size); \ |
|
|
output_block_size); \ |
|
|
if (!skip_copy_dst) { \ |
|
|
if (!skip_copy_dst) { \ |
|
|
dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \ |
|
|
dst_ctype* dst_tmp_ptr = reinterpret_cast<dst_ctype*>(matmul_dst); \ |
|
|
dst_ctype* dst = \ |
|
|
|
|
|
param.dst<dst_ctype>() + oc_cur_index * OHW + ohw_cur_index; \ |
|
|
|
|
|
|
|
|
dst_ctype* dst = param.dst<dst_ctype>(batch_id, group_id) + \ |
|
|
|
|
|
oc_cur_index * OHW + ohw_cur_index; \ |
|
|
for (size_t oc = 0; oc < output_block_oc_size; oc++) { \ |
|
|
for (size_t oc = 0; oc < output_block_oc_size; oc++) { \ |
|
|
std::memcpy(dst, dst_tmp_ptr, \ |
|
|
std::memcpy(dst, dst_tmp_ptr, \ |
|
|
sizeof(dst_ctype) * output_block_size); \ |
|
|
sizeof(dst_ctype) * output_block_size); \ |
|
|
@@ -243,7 +247,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, |
|
|
bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
bias_ctype* matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
param, bundle_thread, \ |
|
|
param, bundle_thread, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_IM2COL_INDEX, oc_cur_index, OHW, \ |
|
|
is_dst_8bit, is_ohw_size_bigger); |
|
|
|
|
|
|
|
|
is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); |
|
|
|
|
|
|
|
|
#define MATMUL_COMPUTE() \ |
|
|
#define MATMUL_COMPUTE() \ |
|
|
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ |
|
|
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ |
|
|
@@ -272,6 +276,7 @@ public: |
|
|
ConvBiasImpl::NCBKernIndex ncb_index) { |
|
|
ConvBiasImpl::NCBKernIndex ncb_index) { |
|
|
bundle.set(param.workspace_ptr); |
|
|
bundle.set(param.workspace_ptr); |
|
|
fallback::MatrixMulImpl::KernParam matmul_param; |
|
|
fallback::MatrixMulImpl::KernParam matmul_param; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = |
|
|
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = |
|
|
matmulparam; |
|
|
matmulparam; |
|
|
size_t packA_group_size = |
|
|
size_t packA_group_size = |
|
|
@@ -283,11 +288,11 @@ public: |
|
|
matmul_algo->get_packA_type_size(); |
|
|
matmul_algo->get_packA_type_size(); |
|
|
size_t a_panel_offset = |
|
|
size_t a_panel_offset = |
|
|
ncb_index.ndrange_id[2] * packed_per_oc_block_size; |
|
|
ncb_index.ndrange_id[2] * packed_per_oc_block_size; |
|
|
int8_t* a_panel = |
|
|
|
|
|
static_cast<int8_t*>( |
|
|
|
|
|
bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + |
|
|
|
|
|
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; |
|
|
|
|
|
matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()); |
|
|
|
|
|
|
|
|
int8_t* a_panel = static_cast<int8_t*>(bundle.get( |
|
|
|
|
|
Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + |
|
|
|
|
|
group_id * packA_group_size + a_panel_offset; |
|
|
|
|
|
matmul_param.A_ptr = |
|
|
|
|
|
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); |
|
|
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2], |
|
|
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[2], |
|
|
matmul_algo->get_inner_block_size().m); |
|
|
matmul_algo->get_inner_block_size().m); |
|
|
}; |
|
|
}; |
|
|
@@ -309,6 +314,8 @@ public: |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
size_t OHW = OH * OW; |
|
|
size_t OHW = OH * OW; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
|
|
|
size_t batch_id = ncb_index.ndrange_id[1]; |
|
|
size_t output_block_size = std::min( |
|
|
size_t output_block_size = std::min( |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
size_t output_block_oc_size = std::min( |
|
|
size_t output_block_oc_size = std::min( |
|
|
@@ -369,11 +376,11 @@ public: |
|
|
\ |
|
|
\ |
|
|
src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ |
|
|
src_ctype* a_panel = PtrGetter::get_bundle_offset_byte_ptr<src_ctype>( \ |
|
|
bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \ |
|
|
bundle, Im2colBundelIndex::BUNDLE_PACKA_INDEX, \ |
|
|
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset); \ |
|
|
|
|
|
|
|
|
group_id * packA_group_size + a_panel_offset); \ |
|
|
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
param, bundle_thread, \ |
|
|
param, bundle_thread, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ |
|
|
OHW, is_dst_8bit, is_ohw_size_bigger); |
|
|
|
|
|
|
|
|
OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); |
|
|
|
|
|
|
|
|
#define MATMUL_COMPUTE() \ |
|
|
#define MATMUL_COMPUTE() \ |
|
|
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ |
|
|
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); \ |
|
|
@@ -402,6 +409,7 @@ public: |
|
|
matmulparam; |
|
|
matmulparam; |
|
|
size_t OC = param.filter_meta.ocpg; |
|
|
size_t OC = param.filter_meta.ocpg; |
|
|
size_t oc_tile_size = matmul_param.M; |
|
|
size_t oc_tile_size = matmul_param.M; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
size_t output_block_oc_size = std::min( |
|
|
size_t output_block_oc_size = std::min( |
|
|
oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size); |
|
|
oc_tile_size, OC - ncb_index.ndrange_id[2] * oc_tile_size); |
|
|
size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size; |
|
|
size_t oc_cur_index = ncb_index.ndrange_id[2] * oc_tile_size; |
|
|
@@ -411,12 +419,12 @@ public: |
|
|
size_t a_panel_offset = |
|
|
size_t a_panel_offset = |
|
|
ncb_index.ndrange_id[2] * |
|
|
ncb_index.ndrange_id[2] * |
|
|
matmul_algo->get_bundle(matmul_param).get_size(0); |
|
|
matmul_algo->get_bundle(matmul_param).get_size(0); |
|
|
int8_t* a_panel = |
|
|
|
|
|
static_cast<int8_t*>( |
|
|
|
|
|
bundle.get(Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + |
|
|
|
|
|
ncb_index.ndrange_id[0] * packA_group_size + a_panel_offset; |
|
|
|
|
|
matmul_param.A_ptr = const_cast<src_ctype*>(param.filter<src_ctype>()) + |
|
|
|
|
|
oc_cur_index * matmul_param.K; |
|
|
|
|
|
|
|
|
int8_t* a_panel = static_cast<int8_t*>(bundle.get( |
|
|
|
|
|
Im2colBundelIndex::BUNDLE_PACKA_INDEX)) + |
|
|
|
|
|
group_id * packA_group_size + a_panel_offset; |
|
|
|
|
|
matmul_param.A_ptr = |
|
|
|
|
|
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)) + |
|
|
|
|
|
oc_cur_index * matmul_param.K; |
|
|
matmul_param.M = output_block_oc_size; |
|
|
matmul_param.M = output_block_oc_size; |
|
|
matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); |
|
|
matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); |
|
|
}; |
|
|
}; |
|
|
@@ -437,6 +445,8 @@ public: |
|
|
MEGDNN_MARK_USED_VAR(N); |
|
|
MEGDNN_MARK_USED_VAR(N); |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
|
|
|
size_t batch_id = ncb_index.ndrange_id[1]; |
|
|
size_t OHW = OH * OW; |
|
|
size_t OHW = OH * OW; |
|
|
size_t output_block_size = std::min( |
|
|
size_t output_block_size = std::min( |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
@@ -490,11 +500,11 @@ public: |
|
|
#define PREPAR_MATMUL_DATA() \ |
|
|
#define PREPAR_MATMUL_DATA() \ |
|
|
bias_ctype* matmul_dst = nullptr; \ |
|
|
bias_ctype* matmul_dst = nullptr; \ |
|
|
const src_ctype* filter = \ |
|
|
const src_ctype* filter = \ |
|
|
param.filter<src_ctype>() + oc_cur_index * IC * FH * FW; \ |
|
|
|
|
|
|
|
|
param.filter<src_ctype>(group_id) + oc_cur_index * IC * FH * FW; \ |
|
|
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
matmul_dst = PtrGetter::get_matmul_dst_ptr<bias_ctype>( \ |
|
|
param, bundle_thread, \ |
|
|
param, bundle_thread, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ |
|
|
Im2colBundelIndex::THREAD_BUNDLE_MATMUL_DST_INDEX, oc_cur_index, \ |
|
|
OHW, is_dst_8bit, is_ohw_size_bigger); |
|
|
|
|
|
|
|
|
OHW, is_dst_8bit, is_ohw_size_bigger, batch_id, group_id); |
|
|
|
|
|
|
|
|
#define MATMUL_COMPUTE() \ |
|
|
#define MATMUL_COMPUTE() \ |
|
|
matmul_param.M = output_block_oc_size; \ |
|
|
matmul_param.M = output_block_oc_size; \ |
|
|
@@ -526,6 +536,8 @@ public: |
|
|
MEGDNN_MARK_USED_VAR(N); |
|
|
MEGDNN_MARK_USED_VAR(N); |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IH2 = IH + 2 * PH; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
auto IW2 = IW + 2 * PW; |
|
|
|
|
|
size_t group_id = ncb_index.ndrange_id[0]; |
|
|
|
|
|
size_t batch_id = ncb_index.ndrange_id[1]; |
|
|
size_t OHW = OH * OW; |
|
|
size_t OHW = OH * OW; |
|
|
size_t output_block_size = std::min( |
|
|
size_t output_block_size = std::min( |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
ohw_tile_size, OHW - ncb_index.ndrange_id[2] * ohw_tile_size); |
|
|
|