| @@ -40,7 +40,8 @@ size_t ConvBiasImpl::AlgoConv1x1::get_oc_tile_size_heuristic( | |||
| size_t OC = param.filter_meta.ocpg; | |||
| if (OH * OW >= 56 * 56 || OC >= 64) | |||
| return m_oc_block_size; | |||
| return div_ceil(OC, param.nr_threads); | |||
| size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); | |||
| return round_up<size_t>(oc_block_size_one_thread, 24); | |||
| } | |||
| size_t ConvBiasImpl::AlgoConv1x1::get_workspace( | |||
| @@ -180,8 +181,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy) const { | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | |||
| //! only support nchw and nchw44 formats | |||
| if (opr->param().format != param::ConvBias::Format::NCHW) | |||
| if (opr->param().format != param::ConvBias::Format::NCHW && | |||
| opr->param().format != param::ConvBias::Format::NCHW44) | |||
| return false; | |||
| size_t FH = param.filter_meta.spatial[0], | |||
| @@ -218,8 +219,12 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
| MatrixMulImpl::KernSizeParam matmul_param = | |||
| get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
| bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
| return matmulusable && | |||
| if(opr->param().format == param::ConvBias::Format::NCHW44) | |||
| matmul_param.format = param::MatrixMul::Format::MK4; | |||
| bool matmul_usable = m_matmul_algo->usable(matmul_param); | |||
| return matmul_usable && | |||
| (param.filter_meta.dilation[0] == | |||
| param.filter_meta.dilation[1] && | |||
| param.filter_meta.dilation[0] == 1) && | |||
| @@ -71,33 +71,32 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| const ConvBiasImpl::NCBKernSizeParam& param, | |||
| MatrixMulImpl::AlgoBase::PackMode pack_mode, | |||
| param::ConvBias::Format format) { | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
| _postprocess_mode, _packmode>>(); \ | |||
| } \ | |||
| } \ | |||
| size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
| _postprocess_mode, _packmode>>(pack_size); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
| DTypeTrait<_i_bias_type>::ctype, \ | |||
| DTypeTrait<_i_dst_type>::ctype, \ | |||
| _postprocess_mode, _packmode>>(); \ | |||
| } \ | |||
| } \ | |||
| #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
| DTypeTrait<_i_bias_type>::ctype, \ | |||
| DTypeTrait<_i_dst_type>::ctype, \ | |||
| _postprocess_mode, _packmode>>(pack_size); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| switch (pack_mode) { | |||
| @@ -88,6 +88,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| megdnn::PostprocessMode postprocess_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode> | |||
| class Conv1x1Strategy : public Conv1x1StrategyBase { | |||
| public: | |||
| explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {} | |||
| void packA(WorkspaceBundle& whole_bundle, | |||
| WorkspaceBundle& matmul_bundle, | |||
| size_t oc_tile_size, | |||
| @@ -133,6 +135,9 @@ public: | |||
| src_ctype* a_panel = reinterpret_cast<src_ctype*>( | |||
| reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | |||
| bytes_offset_of_a_panel); | |||
| matmul_kern_param.LDA *= m_pack_size; | |||
| matmul_kern_param.A_ptr = const_cast<src_ctype*>( | |||
| ncb_param.filter<src_ctype>(group_id) + | |||
| numbers_offset_of_filter); | |||
| @@ -165,6 +170,8 @@ public: | |||
| static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | |||
| get_matmul_kern_param(param, OH * OW, OC); | |||
| matmul_kern_param.LDB *= m_pack_size; | |||
| rep(batch, BATCH) { | |||
| rep(g, GROUP) { | |||
| if (SH == 2 && SW == 2) | |||
| @@ -273,6 +280,8 @@ public: | |||
| matmul_kern_param.C_ptr = matmul_dst; | |||
| matmul_kern_param.LDC *= m_pack_size; | |||
| if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
| auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | |||
| matmul_kern(matmul_kern_param); | |||
| @@ -291,11 +300,14 @@ public: | |||
| else | |||
| bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
| ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | |||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
| matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | |||
| param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
| oc_end - oc_start, OH, OW); | |||
| (oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size); | |||
| } | |||
| private: | |||
| size_t m_pack_size = 1; | |||
| }; | |||
| class Conv1x1Factory { | |||