| @@ -407,6 +407,11 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern( | |||||
| return kern_mk8_16x12x1; | return kern_mk8_16x12x1; | ||||
| } | } | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( | |||||
AlgoF16MK8_16x12x1, megdnn_aarch64_matmul_kern, "AlgoF16MK8_16x12x1Impl"_hash, | |||||
| aarch64::matmul::hgemm_mk8_16x12, dt_float16, dt_float16, AlgoDataType::FLOAT16, | |||||
| MK8); | |||||
| #endif | #endif | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| @@ -93,7 +93,7 @@ public: | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8); | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); | MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); | ||||
| }; | }; | ||||
| @@ -9,8 +9,8 @@ | |||||
| template <> | template <> | ||||
| void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | ||||
| const dt_float16* packedA, const dt_float16* packedB, int K, | |||||
| dt_float16* out, int LDC, bool is_first_k) { | |||||
| const dt_float16* packedA, const dt_float16* packedB, int K, dt_float16* out, | |||||
| int LDC, bool is_first_k) { | |||||
| #define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | #define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | ||||
| #define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | #define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | ||||
| // clang-format off | // clang-format off | ||||
| @@ -26,6 +26,8 @@ static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||||
| format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
| } else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | } else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | ||||
| format = param::MatrixMul::Format::MK4_DOT; | format = param::MatrixMul::Format::MK4_DOT; | ||||
| } else if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
| format = param::MatrixMul::Format::MK8; | |||||
| } | } | ||||
| size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
| size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
| @@ -329,9 +331,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| if (format != param::ConvBias::Format::NCHW && | if (format != param::ConvBias::Format::NCHW && | ||||
| format != param::ConvBias::Format::NCHW44 && | format != param::ConvBias::Format::NCHW44 && | ||||
| format != param::ConvBias::Format::NCHW44_DOT) { | |||||
| format != param::ConvBias::Format::NCHW44_DOT && | |||||
| format != param::ConvBias::Format::NCHW88) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (format == param::ConvBias::Format::NCHW88) { | |||||
| if (matmul_desc.packmode != Pack_Mode::DEFAULT) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (format == param::ConvBias::Format::NCHW44 || | if (format == param::ConvBias::Format::NCHW44 || | ||||
| format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
| //! current NCHW44 im2col only support DEFAULT mode matmul | //! current NCHW44 im2col only support DEFAULT mode matmul | ||||
| @@ -248,8 +248,18 @@ public: | |||||
| break; | break; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| case StrategyType::FLOAT_FP16: | case StrategyType::FLOAT_FP16: | ||||
| cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
| "DefaultStrategyType::FLOAT_FP16"_hash); | |||||
| if (format == param::ConvBias::Format::NCHW) { | |||||
| cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
| "DefaultStrategyType::FLOAT_FP16"_hash); | |||||
| } else if (format == param::ConvBias::Format::NCHW88) { | |||||
| cb1(NCHW88, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
| "DefaultStrategyTypeNCHW88::FLOAT_FP16"_hash); | |||||
| } else { | |||||
| megdnn_throw(ssprintf( | |||||
| "Current only support layout NCHW/NCHW88 for im2col algo " | |||||
| "of float 16, but got %d\n", | |||||
| uint32_t(format))); | |||||
| } | |||||
| break; | break; | ||||
| #endif | #endif | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| @@ -343,6 +343,32 @@ public: | |||||
| const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | ||||
| }; | }; | ||||
| template < | |||||
| typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||||
| typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||||
| class Strategy< | |||||
| src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||||
| PackMode::DEFAULT, FormatMode::NCHW88> | |||||
| : public Strategy< | |||||
| src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT> { | |||||
| public: | |||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||||
| constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||||
| constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||||
| constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||||
| Strategy() = default; | |||||
| void exec_im2col( | |||||
| const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||||
| const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
| }; | |||||
| template < | template < | ||||
| typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | ||||
| typename op_dtype, megdnn::PostprocessMode postprocess_mode> | typename op_dtype, megdnn::PostprocessMode postprocess_mode> | ||||
| @@ -0,0 +1,98 @@ | |||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||||
| #include "src/fallback/convolution/img2col_helper.h" | |||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
| #else | |||||
| #include "src/common/postprocess_helper.h" | |||||
| #endif | |||||
| using namespace megdnn; | |||||
| #if MEGDNN_X86 | |||||
| using namespace x86; | |||||
| #endif | |||||
| namespace megdnn { | |||||
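//! exec_im2col for the NCHW88 DEFAULT-pack strategy: pick the padded or | |||||
//! raw 8-channel-packed source, expand the filter window with | |||||
//! img2col_nchw8 / img2col_stride_nchw8, then pack the resulting matrix B | |||||
//! for the MK8 matmul. | |||||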
| template < | |||||
| typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||||
| typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy< | |||||
| src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||||
| PackMode::DEFAULT, FormatMode::NCHW88>:: | |||||
| exec_im2col( | |||||
| const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||||
| const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
| size_t sh = param.filter_meta.stride[0]; | |||||
| size_t sw = param.filter_meta.stride[1]; | |||||
| size_t ow = param.osz[1]; | |||||
| size_t ic = param.filter_meta.icpg; | |||||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
| size_t fh = param.filter_meta.spatial[0]; | |||||
| size_t fw = param.filter_meta.spatial[1]; | |||||
bool is_xcorr = !param.filter_meta.should_flip; | |||||
| constexpr static size_t pack_size = 8; | |||||
| size_t input_offset = | |||||
| ic * ih * iw * | |||||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||||
| sizeof(src_ctype); | |||||
| src_ctype* src = reinterpret_cast<src_ctype*>( | |||||
| reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| input_offset); | |||||
| bool is_phpwzero = | |||||
| (param.filter_meta.padding[0] == 0 && param.filter_meta.padding[1] == 0); | |||||
| if (is_phpwzero) { | |||||
| src = const_cast<src_ctype*>( | |||||
| param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| } | |||||
| src_ctype* im2col_dst = | |||||
| reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||||
| if (sh == 1 && sw == 1) { | |||||
if (is_xcorr) { | |||||
| img2col_nchw8<true>( | |||||
| src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||||
| sparam.output_block_size); | |||||
| } else { | |||||
| img2col_nchw8<false>( | |||||
| src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||||
| sparam.output_block_size); | |||||
| } | |||||
| } else { | |||||
if (is_xcorr) { | |||||
| img2col_stride_nchw8<true>( | |||||
| src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||||
| sparam.ohw_cur_index, sparam.output_block_size); | |||||
| } else { | |||||
| img2col_stride_nchw8<false>( | |||||
| src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||||
| sparam.ohw_cur_index, sparam.output_block_size); | |||||
| } | |||||
| } | |||||
| matmul_param.M = sparam.output_block_oc_size; | |||||
| matmul_param.N = sparam.output_block_size; | |||||
| matmul_param.LDB = pack_size * sparam.output_block_size; | |||||
| matmul_param.LDC = pack_size * sparam.output_block_size; | |||||
| matmul_param.B_ptr = im2col_dst; | |||||
| src_ctype* b_panel = | |||||
| reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)); | |||||
| matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||||
| } | |||||
| #define INSTANTIAL_CLASS( \ | |||||
| _src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, _postprocess_mode) \ | |||||
| template class Strategy< \ | |||||
| _src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, \ | |||||
| _postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW88>; | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| INSTANTIAL_CLASS( | |||||
| dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
| megdnn::PostprocessMode::FLOAT); | |||||
| #endif | |||||
| #undef INSTANTIAL_CLASS | |||||
| } // namespace megdnn | |||||
| @@ -347,6 +347,441 @@ void img2col_nchw4( | |||||
| } | } | ||||
| } | } | ||||
| template <bool is_xcorr, typename dtype> | |||||
| void img2col_nchw8( | |||||
| const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||||
| const int IH, const int IW, const int FH, const int FW, const int cur_index, | |||||
| const int block_size) { | |||||
| int start_h = cur_index / OW; | |||||
| int cur_n_remain = cur_index % OW; | |||||
| int end_h = (cur_index + block_size) / OW; | |||||
| int end_n_remain = (cur_index + block_size) % OW; | |||||
| bool same_line = (start_h == end_h); | |||||
| int IC_div_8 = IC / 8; | |||||
| if (sizeof(dtype) == 2) { | |||||
| if (same_line) { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
//! TODO: replace the Arm intrinsics with GI once GI supports the | |||||
//! FP16 data type. | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + | |||||
| cur_n_remain + fw2); | |||||
| for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8; | |||||
| } | |||||
| #else | |||||
| int src_idx = 2 * (ic * IH * IW + (start_h + fh2) * IW + | |||||
| cur_n_remain + fw2); | |||||
const uint64_t* u64_src = reinterpret_cast<const uint64_t*>(src); | |||||
| uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
| for (int w = cur_n_remain; w < end_n_remain; w++) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| int src_idx = 8 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||||
| cur_n_remain); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| int _src_idx = src_idx; | |||||
| rep(w, OW) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| _src_idx)); | |||||
| dst_idx += 8; | |||||
| _src_idx += 8; | |||||
| } | |||||
| src_idx += IW * 8; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8; | |||||
| } | |||||
| #else | |||||
const uint64_t* u64_src = reinterpret_cast<const uint64_t*>(src); | |||||
| uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
| int src_idx = 2 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||||
| cur_n_remain); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2; | |||||
| } | |||||
| src_idx = 2 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| int _src_idx = src_idx; | |||||
| rep(w, OW) { | |||||
| u64_dst[dst_idx] = u64_src[_src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| _src_idx += 2; | |||||
| } | |||||
| src_idx += IW * 2; | |||||
| } | |||||
| src_idx = 2 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (same_line) { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||||
| cur_n_remain); | |||||
| for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||||
| cur_n_remain); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (start_h + 1 + fh2) * IW + fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| rep(w, OW) { | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| } | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (end_h + fh2) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| dst[dst_idx++] = src[src_idx++]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <bool is_xcorr, typename dtype> | |||||
| void img2col_stride_nchw8( | |||||
| const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||||
| const int IH, const int IW, const int FH, const int FW, const int SH, | |||||
| const int SW, const int cur_index, const int block_size) { | |||||
| int start_h = cur_index / OW; | |||||
| int cur_n_remain = cur_index % OW; | |||||
| int end_h = (cur_index + block_size) / OW; | |||||
| int end_n_remain = (cur_index + block_size) % OW; | |||||
| bool same_line = (start_h == end_h); | |||||
| int IC_div_8 = IC / 8; | |||||
| if (sizeof(dtype) == 2) { | |||||
| if (same_line) { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
| cur_n_remain * SW + fw2); | |||||
| for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| #else | |||||
| int src_idx = 2 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
| cur_n_remain * SW + fw2); | |||||
const uint64_t* u64_src = reinterpret_cast<const uint64_t*>(src); | |||||
| uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
| for (int w = cur_n_remain; w < end_n_remain; w++) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2 * SW; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| int src_idx = 8 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||||
| fw2 + cur_n_remain * SW); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||||
| fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| int _src_idx = src_idx; | |||||
| rep(w, OW) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| _src_idx)); | |||||
| dst_idx += 8; | |||||
| _src_idx += 8 * SW; | |||||
| } | |||||
| src_idx += IW * 8 * SH; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| vst1q_f16( | |||||
| reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
| vld1q_f16( | |||||
| reinterpret_cast<const __fp16*>(src) + | |||||
| src_idx)); | |||||
| dst_idx += 8; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| #else | |||||
const uint64_t* u64_src = reinterpret_cast<const uint64_t*>(src); | |||||
| uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
| int src_idx = 2 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||||
| fw2 + cur_n_remain * SW); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2 * SW; | |||||
| } | |||||
| src_idx = 2 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||||
| fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| int _src_idx = src_idx; | |||||
| rep(w, OW) { | |||||
| u64_dst[dst_idx] = u64_src[_src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| _src_idx += 2 * SW; | |||||
| } | |||||
| src_idx += IW * 2 * SH; | |||||
| } | |||||
| src_idx = 2 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| u64_dst[dst_idx] = u64_src[src_idx]; | |||||
| u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
| dst_idx += 2; | |||||
| src_idx += 2 * SW; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (same_line) { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
| fw2 + cur_n_remain * SW); | |||||
| for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
| dst[dst_idx++] = src[src_idx]; | |||||
| dst[dst_idx++] = src[src_idx + 1]; | |||||
| dst[dst_idx++] = src[src_idx + 2]; | |||||
| dst[dst_idx++] = src[src_idx + 3]; | |||||
| dst[dst_idx++] = src[src_idx + 4]; | |||||
| dst[dst_idx++] = src[src_idx + 5]; | |||||
| dst[dst_idx++] = src[src_idx + 6]; | |||||
| dst[dst_idx++] = src[src_idx + 7]; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| int dst_idx = 0; | |||||
| rep(ic, IC_div_8) { | |||||
| rep(fh, FH) { | |||||
| rep(fw, FW) { | |||||
| int fh2 = fh, fw2 = fw; | |||||
| if (!is_xcorr) { | |||||
| fh2 = FH - fh - 1; | |||||
| fw2 = FW - fw - 1; | |||||
| } | |||||
| int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
| fw2 + cur_n_remain * SW); | |||||
| for (int w = cur_n_remain; w < OW; ++w) { | |||||
| dst[dst_idx++] = src[src_idx]; | |||||
| dst[dst_idx++] = src[src_idx + 1]; | |||||
| dst[dst_idx++] = src[src_idx + 2]; | |||||
| dst[dst_idx++] = src[src_idx + 3]; | |||||
| dst[dst_idx++] = src[src_idx + 4]; | |||||
| dst[dst_idx++] = src[src_idx + 5]; | |||||
| dst[dst_idx++] = src[src_idx + 6]; | |||||
| dst[dst_idx++] = src[src_idx + 7]; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + ((start_h + 1) * SH + fh2) * IW + | |||||
| fw2); | |||||
| for (int h = start_h + 1; h < end_h; ++h) { | |||||
| rep(w, OW) { | |||||
| dst[dst_idx++] = src[src_idx]; | |||||
| dst[dst_idx++] = src[src_idx + 1]; | |||||
| dst[dst_idx++] = src[src_idx + 2]; | |||||
| dst[dst_idx++] = src[src_idx + 3]; | |||||
| dst[dst_idx++] = src[src_idx + 4]; | |||||
| dst[dst_idx++] = src[src_idx + 5]; | |||||
| dst[dst_idx++] = src[src_idx + 6]; | |||||
| dst[dst_idx++] = src[src_idx + 7]; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| } | |||||
| src_idx = 8 * (ic * IH * IW + (end_h * SH + fh2) * IW + fw2); | |||||
| rep(w, end_n_remain) { | |||||
| dst[dst_idx++] = src[src_idx]; | |||||
| dst[dst_idx++] = src[src_idx + 1]; | |||||
| dst[dst_idx++] = src[src_idx + 2]; | |||||
| dst[dst_idx++] = src[src_idx + 3]; | |||||
| dst[dst_idx++] = src[src_idx + 4]; | |||||
| dst[dst_idx++] = src[src_idx + 5]; | |||||
| dst[dst_idx++] = src[src_idx + 6]; | |||||
| dst[dst_idx++] = src[src_idx + 7]; | |||||
| src_idx += 8 * SW; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <bool is_xcorr, typename dtype> | template <bool is_xcorr, typename dtype> | ||||
| void img2col_stride( | void img2col_stride( | ||||
| const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, | const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, | ||||
| @@ -68,6 +68,87 @@ void benchmark_impl( | |||||
| multi_thread_config.nr_thread); | multi_thread_config.nr_thread); | ||||
| } | } | ||||
| } | } | ||||
| void benchmark_with_contrast( | |||||
| const std::vector<conv_bias::TestArg>& args, const std::string algo_name, | |||||
| std::vector<DType>& data_type, | |||||
| const std::vector<conv_bias::TestArg>& args_contrast, | |||||
| const std::string algo_name_contrast, std::vector<DType>& data_type_contrast, | |||||
| size_t RUNS, TaskExecutorConfig&& single_thread_config) { | |||||
| auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); | |||||
| auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | |||||
| auto benchmarker_contrast = Benchmarker<ConvBias>(single_thread_handle.get()); | |||||
| benchmarker.set_times(RUNS) | |||||
| .set_display(false) | |||||
| .set_dtype(0, data_type[0]) | |||||
| .set_dtype(1, data_type[1]) | |||||
| .set_dtype(2, data_type[2]) | |||||
| .set_dtype(4, data_type[3]) | |||||
| .set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||||
| benchmarker_contrast.set_times(RUNS) | |||||
| .set_display(false) | |||||
| .set_dtype(0, data_type_contrast[0]) | |||||
| .set_dtype(1, data_type_contrast[1]) | |||||
| .set_dtype(2, data_type_contrast[2]) | |||||
| .set_dtype(4, data_type_contrast[3]) | |||||
| .set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||||
| algo_name_contrast.c_str())); | |||||
| size_t arg_size = args.size(), arg_contrast_size = args_contrast.size(); | |||||
| megdnn_assert(arg_size == arg_contrast_size); | |||||
| rep(i, arg_size) { | |||||
| TensorLayout dst_layout, dst_layout_contrast; | |||||
| auto opr = single_thread_handle.get()->create_operator<ConvBias>(); | |||||
| auto&& arg = args[i]; | |||||
| opr->param() = arg.param; | |||||
| opr->deduce_layout( | |||||
| {arg.src, data_type[0]}, {arg.filter, data_type[1]}, | |||||
| {arg.bias, data_type[2]}, {}, dst_layout); | |||||
| float computation = (dst_layout.total_nr_elems() * arg.filter[1] * | |||||
| arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) / | |||||
| (1024 * 1024 * 1024) * 1e3; | |||||
| benchmarker.set_param(arg.param); | |||||
| auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS; | |||||
| auto&& arg_contrast = args_contrast[i]; | |||||
| opr->param() = arg_contrast.param; | |||||
| opr->deduce_layout( | |||||
| {arg_contrast.src, data_type_contrast[0]}, | |||||
| {arg_contrast.filter, data_type_contrast[1]}, | |||||
| {arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast); | |||||
| float computation_contrast = | |||||
| (dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] * | |||||
| arg_contrast.filter[2] * arg_contrast.filter[3] * | |||||
| arg_contrast.filter[4] * 2.0) / | |||||
| (1024 * 1024 * 1024) * 1e3; | |||||
| benchmarker_contrast.set_param(arg_contrast.param); | |||||
| auto used_contrast = benchmarker_contrast.exec( | |||||
| {arg_contrast.src, | |||||
| arg_contrast.filter, | |||||
| arg_contrast.bias, | |||||
| {}, | |||||
| {}}) / | |||||
| RUNS; | |||||
| printf("Bench case: \n"); | |||||
| printf("padding: %u, stride: %u, nonline mode: %u\n", arg.param.pad_h, | |||||
arg.param.stride_h, static_cast<uint32_t>(arg.param.nonlineMode)); | |||||
| printf("%s %s %s\n", arg.src.to_string().c_str(), | |||||
| arg.filter.to_string().c_str(), arg.bias.to_string().c_str()); | |||||
| printf("%s %s %s\n", arg_contrast.src.to_string().c_str(), | |||||
| arg_contrast.filter.to_string().c_str(), | |||||
| arg_contrast.bias.to_string().c_str()); | |||||
| printf("%s: %f gflops;\n%s: %f gflops\n" | |||||
| "spead up = %f\n", | |||||
| algo_name.c_str(), computation / used, algo_name_contrast.c_str(), | |||||
| computation_contrast / used_contrast, used_contrast / used); | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -1591,6 +1672,91 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) { | |||||
| data_type); | data_type); | ||||
| shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_NCHW88) { | |||||
| constexpr size_t RUNS = 50; | |||||
| using NLMode = param::ConvBias::NonlineMode; | |||||
| std::vector<conv_bias::TestArg> args_nchw88, args_nchw44; | |||||
| auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
| size_t group) { | |||||
| param::ConvBias param_nchw88, param_nchw44; | |||||
| param_nchw88.format = param::ConvBias::Format::NCHW88; | |||||
| param_nchw44.format = param::ConvBias::Format::NCHW44; | |||||
| for (size_t pad : {1, 2, 4}) { | |||||
| for (size_t stride : {1, 2, 3}) { | |||||
| for (auto nlmode : | |||||
| {NLMode::RELU, NLMode::IDENTITY, NLMode::SIGMOID, | |||||
| NLMode::H_SWISH}) { | |||||
| param_nchw88.nonlineMode = nlmode; | |||||
| param_nchw88.pad_h = pad; | |||||
| param_nchw88.pad_w = pad; | |||||
| param_nchw88.stride_h = stride; | |||||
| param_nchw88.stride_w = stride; | |||||
| param_nchw44.nonlineMode = nlmode; | |||||
| param_nchw44.pad_h = pad; | |||||
| param_nchw44.pad_w = pad; | |||||
| param_nchw44.stride_h = stride; | |||||
| param_nchw44.stride_w = stride; | |||||
| args_nchw88.emplace_back( | |||||
| param_nchw88, TensorShape{N, IC / 8, H, W, 8}, | |||||
| TensorShape{OC / 8, IC / group / 8, FS, FS, 8, 8}, | |||||
| TensorShape{1, OC / 8, 1, 1, 8}); | |||||
| args_nchw44.emplace_back( | |||||
| param_nchw44, TensorShape{N, IC / 4, H, W, 4}, | |||||
| TensorShape{OC / 4, IC / group / 4, FS, FS, 4, 4}, | |||||
| TensorShape{1, OC / 4, 1, 1, 4}); | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
| std::vector<DType> data_type_fp16 = { | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; | |||||
| std::vector<DType> data_type_fp32 = { | |||||
| dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
| bench_case(1, 32, 32, 300, 300, 3, 1); | |||||
| bench_case(1, 32, 32, 400, 400, 3, 1); | |||||
| bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
| bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
| bench_case(1, 32, 64, 200, 200, 3, 1); | |||||
| bench_case(1, 32, 64, 128, 128, 3, 1); | |||||
| bench_case(1, 32, 64, 100, 100, 3, 1); | |||||
| bench_case(1, 32, 64, 80, 80, 3, 1); | |||||
| bench_case(1, 32, 128, 200, 200, 3, 1); | |||||
| bench_case(1, 32, 128, 128, 128, 3, 1); | |||||
| bench_case(1, 32, 128, 100, 100, 3, 1); | |||||
| bench_case(1, 32, 128, 80, 80, 3, 1); | |||||
| bench_case(1, 64, 32, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 64, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 128, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 256, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 512, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 1024, 7, 7, 3, 1); | |||||
| bench_case(1, 64, 32, 14, 14, 3, 1); | |||||
| bench_case(1, 64, 64, 14, 14, 3, 1); | |||||
| bench_case(1, 64, 128, 14, 14, 3, 1); | |||||
| bench_case(1, 64, 256, 14, 14, 3, 1); | |||||
| bench_case(1, 64, 512, 14, 14, 3, 1); | |||||
| bench_case(1, 64, 1024, 14, 14, 3, 1); | |||||
| bench_case(1, 128, 128, 14, 14, 3, 1); | |||||
| bench_case(1, 128, 256, 14, 14, 3, 1); | |||||
| bench_case(1, 512, 512, 14, 14, 3, 1); | |||||
| bench_case(1, 256, 512, 14, 14, 3, 1); | |||||
| bench_case(1, 512, 1024, 14, 14, 3, 1); | |||||
| bench_case(1, 1024, 1024, 14, 14, 3, 1); | |||||
| std::string algo_name_nchw88 = "IM2COLMATMUL:AARCH64_F16_MK8_16X12X1:96"; | |||||
| std::string algo_name_nchw44 = "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1:96"; | |||||
| benchmark_with_contrast( | |||||
| args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44, | |||||
| algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); | |||||
| } | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { | BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| @@ -362,6 +362,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { | |||||
| #endif | #endif | ||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_MK8_FP16) { | |||||
| using namespace conv_bias; | |||||
| std::vector<conv_bias::TestArg> args = get_nchw88_conv_bias_args( | |||||
| {2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
| auto args1 = get_nchw88_conv_bias_args( | |||||
| {2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 2, 3); | |||||
args.insert(args.begin(), args1.begin(), args1.end()); | |||||
| args1 = get_nchw88_conv_bias_args( | |||||
| {2, 3, 4, 5, 6, 7, 9}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 3, 4); | |||||
args.insert(args.begin(), args1.begin(), args1.end()); | |||||
| NormalRNG rng(1); | |||||
| #define cb(name) \ | |||||
| checker_conv_bias_common( \ | |||||
| args, handle(), &rng, 0.03, dtype::Float16{}, dtype::Float16{}, \ | |||||
| dtype::Float16{}, dtype::Float16{}, name); | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("IM2COLMATMUL:AARCH64_F16_MK8_16X12X1"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| @@ -161,6 +161,24 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { | |||||
| run(); | run(); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) { | |||||
| using Mode = ElemwiseForward::Param::Mode; | |||||
| Checker<ElemwiseForward> checker(handle()); | |||||
| checker.set_epsilon(1e-3); | |||||
| checker.set_dtype(0, dtype::Float16()); | |||||
| checker.set_param(Mode::SIGMOID); | |||||
| for (size_t n : {1, 2, 3}) { | |||||
| for (size_t ic : {8, 16, 24, 32}) { | |||||
| for (size_t ih : {5, 10, 15, 20, 21, 37}) { | |||||
| for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) { | |||||
| checker.exec({{n, ic, ih, iw}, {}}); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | ||||
| using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
| Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||
| @@ -98,6 +98,9 @@ TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) { | |||||
| BENCHMARK_CASES_INT(shape, dtype::Int16()); | BENCHMARK_CASES_INT(shape, dtype::Int16()); | ||||
| BENCHMARK_CASES_INT(shape, dtype::Int8()); | BENCHMARK_CASES_INT(shape, dtype::Int8()); | ||||
| BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); | BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| BENCHMARK_CASES_FLOAT(shape, dtype::Float16()); | |||||
| #endif | |||||
| #undef BENCHMARK_CASES_INT | #undef BENCHMARK_CASES_INT | ||||
| #undef BENCHMARK_CASES_FLOAT | #undef BENCHMARK_CASES_FLOAT | ||||
| #undef RUN | #undef RUN | ||||
| @@ -1580,17 +1580,19 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | ||||
| std::vector<size_t> kernel_vec, | std::vector<size_t> kernel_vec, | ||||
| std::vector<param::ConvBias::NonlineMode> nlmode_vec, | std::vector<param::ConvBias::NonlineMode> nlmode_vec, | ||||
| std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) { | |||||
| std::vector<megdnn::BiasMode> biasmode_vec, size_t stride, int pad) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| using NLMode = param::ConvBias::NonlineMode; | using NLMode = param::ConvBias::NonlineMode; | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, | auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, | ||||
| size_t stride, size_t group, NLMode nlmode, | |||||
| size_t stride, int pad, size_t group, NLMode nlmode, | |||||
| megdnn::BiasMode bias_mode) { | megdnn::BiasMode bias_mode) { | ||||
| constexpr int pack_c = 8; | constexpr int pack_c = 8; | ||||
| const size_t pad = kernel / 2; | |||||
| if (pad == -1) { | |||||
| pad = kernel / 2; | |||||
| } | |||||
| auto oc_per_group = oc / group; | auto oc_per_group = oc / group; | ||||
| auto ic_per_group = ic / group; | auto ic_per_group = ic / group; | ||||
| @@ -1651,8 +1653,8 @@ std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | |||||
| if (kernel < h || kernel < w) { | if (kernel < h || kernel < w) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| pack(n, oc, ic, h, w, kernel, stride, group, | |||||
| nlmode, bias); | |||||
| pack(n, oc, ic, h, w, kernel, stride, pad, | |||||
| group, nlmode, bias); | |||||
| } | } | ||||
| } | } | ||||
| return args; | return args; | ||||