| @@ -49,6 +49,14 @@ namespace { | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
| dst_type, N, OC, OH* OW); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||
| megdnn::arm_common::VEC_BCAST101x4>:: \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY(_op) \ | |||
| megdnn::arm_common:: \ | |||
| OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||
| @@ -57,20 +65,26 @@ namespace { | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||
| dst_type, N* OC* OH* OW); | |||
| #define FOR_BIAS(_mode) \ | |||
| switch (_mode) { \ | |||
| case megdnn::BiasMode::NO_BIAS: \ | |||
| FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||
| break; \ | |||
| case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST) \ | |||
| break; \ | |||
| case megdnn::BiasMode::BIAS: \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_throw("quantized: unsupported bias mode"); \ | |||
| break; \ | |||
| #define FOR_BIAS(_mode) \ | |||
| switch (_mode) { \ | |||
| case megdnn::BiasMode::NO_BIAS: \ | |||
| FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ | |||
| break; \ | |||
| case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| if (pack_oc_size == 1) { \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||
| } else { \ | |||
| megdnn_assert(pack_oc_size == 4, \ | |||
| "Only support nchw44 in ARM"); \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||
| } \ | |||
| break; \ | |||
| case megdnn::BiasMode::BIAS: \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_throw("quantized: unsupported bias mode"); \ | |||
| break; \ | |||
| } | |||
| #define FOR_NONLINEAR(_caller) \ | |||
| @@ -129,6 +143,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| #undef FOR_NONLINEAR_UNARY | |||
| #undef FOR_NONLINEAR_BINARY_BROADCAST | |||
| #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 | |||
| #undef FOR_NONLINEAR_BINARY | |||
| #undef FOR_NONLINEAR_NOBIAS | |||
| #undef FOR_NONLINEAR | |||
| @@ -187,6 +202,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| if (pack_oc_size == 1) { \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||
| } else { \ | |||
| megdnn_assert(pack_oc_size == 4, \ | |||
| "Only support nchw44 in ARM"); \ | |||
| FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||
| } \ | |||
| break; \ | |||
| @@ -216,14 +216,18 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
| return false; | |||
| if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||
| //! nchw44 hybrid mode and channel wise are not supported | |||
| if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | |||
| param.filter_meta.ocpg == 1) { | |||
| return false; | |||
| } | |||
| } | |||
| size_t OH = param.osz[0]; | |||
| size_t OW = param.osz[1]; | |||
| MatrixMulImpl::KernSizeParam matmul_param = | |||
| get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
| if(opr->param().format == param::ConvBias::Format::NCHW44) | |||
| matmul_param.format = param::MatrixMul::Format::MK4; | |||
| MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param( | |||
| param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
| bool matmul_usable = m_matmul_algo->usable(matmul_param); | |||
| return matmul_usable && | |||
| @@ -22,6 +22,20 @@ namespace conv1x1 { | |||
| namespace { | |||
| size_t get_format_pack_size(param::ConvBias::Format format) { | |||
| switch(format){ | |||
| case param::ConvBias::Format::NCHW44: | |||
| case param::ConvBias::Format::NCHW4: | |||
| return 4_z; | |||
| case param::ConvBias::Format::NCHW88: | |||
| return 8_z; | |||
| case param::ConvBias::Format::NCHW: | |||
| return 1_z; | |||
| default: | |||
| megdnn_throw("unknown pack size of the format"); | |||
| } | |||
| } | |||
| struct StrategyHashParam { | |||
| ConvBiasImpl::NCBKernSizeParam param; | |||
| param::ConvBias::Format format; | |||
| @@ -71,7 +85,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| const ConvBiasImpl::NCBKernSizeParam& param, | |||
| MatrixMulImpl::AlgoBase::PackMode pack_mode, | |||
| param::ConvBias::Format format) { | |||
| size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||
| size_t pack_size = get_format_pack_size(format); | |||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| @@ -41,19 +41,25 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||
| (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
| param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | |||
| size_t pack_c_size = 1_z; | |||
| auto format = param::MatrixMul::Format::DEFAULT; | |||
| if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||
| pack_c_size = 4_z; | |||
| format = param::MatrixMul::Format::MK4; | |||
| } | |||
| return {param.filter_type, | |||
| param.src_type, | |||
| is_dst_8bit ? param.bias_type : param.dst_type, | |||
| M, | |||
| N, | |||
| K, | |||
| LDA, | |||
| LDB, | |||
| LDC, | |||
| LDA * pack_c_size, | |||
| LDB * pack_c_size, | |||
| LDC * pack_c_size, | |||
| false, | |||
| false, | |||
| param::MatrixMul::ComputeMode::DEFAULT, | |||
| param::MatrixMul::Format::DEFAULT}; | |||
| format}; | |||
| } | |||
| } // namespace | |||
| @@ -137,9 +143,7 @@ public: | |||
| src_ctype* a_panel = reinterpret_cast<src_ctype*>( | |||
| reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | |||
| bytes_offset_of_a_panel); | |||
| matmul_kern_param.LDA *= m_pack_size; | |||
| matmul_kern_param.A_ptr = const_cast<src_ctype*>( | |||
| ncb_param.filter<src_ctype>(group_id) + | |||
| numbers_offset_of_filter); | |||
| @@ -172,7 +176,6 @@ public: | |||
| static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | |||
| get_matmul_kern_param(param, OH * OW, OC); | |||
| matmul_kern_param.LDB *= m_pack_size; | |||
| rep(batch, BATCH) { | |||
| rep(g, GROUP) { | |||
| @@ -282,8 +285,6 @@ public: | |||
| matmul_kern_param.C_ptr = matmul_dst; | |||
| matmul_kern_param.LDC *= m_pack_size; | |||
| if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
| auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | |||
| matmul_kern(matmul_kern_param); | |||
| @@ -295,14 +296,15 @@ public: | |||
| //! do postprocess | |||
| void* bias_ptr = nullptr; | |||
| if (param.bias_mode == megdnn::BiasMode::BIAS) | |||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||
| bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
| ncb_param.bias<bias_ctype>(batch_id, group_id) + | |||
| numbers_of_ncb_dst_offset)); | |||
| else | |||
| } else { | |||
| bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
| ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | |||
| } | |||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
| matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | |||
| param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
| @@ -137,8 +137,8 @@ class ConvBias { | |||
| sizeof(output_compute_type) * | |||
| std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE); | |||
| size_t matmul_workspace_size = | |||
| matmul_algo->get_workspace(get_matmul_kern_param(param)); | |||
| size_t matmul_workspace_size = matmul_algo->get_workspace( | |||
| get_matmul_kern_param(param, m_unit_oc_size)); | |||
| //! compute workspace is independent and separated as far as possible | |||
| //! in case of false cache line sharing | |||
| @@ -384,7 +384,7 @@ public: | |||
| get_wbundle_compute(param, matmul_algo); | |||
| fallback::MatrixMulImpl::KernParam matmul_param; | |||
| static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | |||
| get_matmul_kern_param(param); | |||
| get_matmul_kern_param(param, m_unit_oc_size); | |||
| Strategy strategy = m_strategy; | |||
| size_t unit_tile_size = m_unit_tile_size; | |||
| @@ -450,21 +450,24 @@ public: | |||
| } | |||
| fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
| const NCBKernSizeParam& param) const { | |||
| const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const { | |||
| size_t M = 0; | |||
| size_t N = 0; | |||
| size_t K = 0; | |||
| size_t LDA = 0, LDB = 0, LDC = 0; | |||
| if (nr_oc_in_unit == 0) { | |||
| nr_oc_in_unit = param.filter_meta.ocpg; | |||
| } | |||
| if (format == param::MatrixMul::Format::DEFAULT) { | |||
| M = m_unit_tile_size; | |||
| N = param.filter_meta.ocpg; | |||
| N = nr_oc_in_unit; | |||
| K = param.filter_meta.icpg; | |||
| LDA = K; | |||
| LDB = N; | |||
| LDC = N; | |||
| } else { | |||
| M = param.filter_meta.ocpg; | |||
| M = nr_oc_in_unit; | |||
| N = m_unit_tile_size; | |||
| K = param.filter_meta.icpg; | |||
| megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu", | |||
| @@ -126,6 +126,8 @@ struct PostProcess { | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn_assert(pack_oc_size == 1, | |||
| "PostProcess only support nchw in x86"); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
| @@ -149,38 +151,6 @@ struct PostProcess { | |||
| } | |||
| }; | |||
| template <typename ctype, typename dtype> | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW, size_t pack_oc_size=1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
| switch (nonlineMode) { | |||
| BIAS_CASE(RELU); | |||
| BIAS_CASE(SIGMOID); | |||
| BIAS_CASE(H_SWISH); | |||
| IDENTITY_CASE(IDENTITY); | |||
| DEFAULT_CASE; | |||
| } | |||
| } else { | |||
| switch (nonlineMode) { | |||
| NOBIAS_CASE(RELU); | |||
| NOBIAS_CASE(SIGMOID); | |||
| NOBIAS_CASE(H_SWISH); | |||
| IDENTITY_CASE(IDENTITY); | |||
| DEFAULT_CASE; | |||
| } | |||
| } | |||
| FOR_BIAS(bias_mode); | |||
| } | |||
| }; | |||
| template <typename ctype, typename dtype> | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
| @@ -297,6 +267,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn_assert(pack_oc_size == 1, | |||
| "PostProcess only support nchw in x86"); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
| @@ -1297,6 +1297,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||
| #endif | |||
| } | |||
| #if MEGDNN_AARCH64 | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||
| check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||
| } | |||
| #endif | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||
| std::vector<conv_bias::TestArg> args_of_4; | |||
| for (auto&& arg : args) { | |||
| if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) { | |||
| args_of_4.push_back(arg); | |||
| } | |||
| } | |||
| #if MEGDNN_AARCH64 | |||
| check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24"); | |||
| #elif MEGDNN_ARMV7 | |||
| check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48"); | |||
| #endif | |||
| } | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | |||
| using namespace conv_bias; | |||