| @@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
| 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
| Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
| '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
| Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | ||||
| 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | ||||
| 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | ||||
| @@ -858,7 +858,10 @@ when the ``I`` suffix is present. | |||||
| 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | 'NCHW_NCHW88_CONV_CHAN_WEIGHT', | ||||
| 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | 'NCHW_NCHW88_CONV_GROUP_WEIGHT', | ||||
| 'NCHW_NCHW88', | 'NCHW_NCHW88', | ||||
| 'NCHW88_NCHW') | |||||
| 'NCHW88_NCHW', | |||||
| 'NCHW_NCHW4_IC_SMALL', | |||||
| 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', | |||||
| ) | |||||
| ) | ) | ||||
| @@ -90,12 +90,11 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride1_2x2_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride1_2x2_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
| const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, | const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, | ||||
| 2, 3, 16, 16, 3, 4, 16, 16}; | 2, 3, 16, 16, 3, 4, 16, 16}; | ||||
| @@ -326,12 +325,11 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride1_3x3_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride1_3x3_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
| const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, | const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, | ||||
| @@ -562,12 +560,11 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride2_2x2_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride2_2x2_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, | const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, | ||||
| @@ -658,12 +655,11 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride2_3x3_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride2_3x3_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, | const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, | ||||
| @@ -814,12 +810,11 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot(const int8_t* src, | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride2_5x5_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride2_5x5_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | ||||
| @@ -1113,12 +1108,11 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride2_7x7_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride2_7x7_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; | ||||
| @@ -1476,12 +1470,11 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride1_5x5_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride1_5x5_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
| const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | ||||
| @@ -1777,12 +1770,11 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot(const int8_t* src, | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| void conv_bias::conv_direct_stride1_7x7_int8_dot(const int8_t* src, | |||||
| const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const size_t IH, | |||||
| const size_t IW, const size_t OH, | |||||
| const size_t OW, const Op& op) { | |||||
| void conv_bias::conv_direct_stride1_7x7_int8_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | |||||
| const size_t OH, const size_t OW, const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(IH); | |||||
| const size_t tail_step = IW - OW; | const size_t tail_step = IW - OW; | ||||
| const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; | ||||
| @@ -29,6 +29,7 @@ void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, | |||||
| const int ih, const int pad_left, | const int ih, const int pad_left, | ||||
| const int pad_right, const int pad_top, | const int pad_right, const int pad_top, | ||||
| const int pad_bottom) { | const int pad_bottom) { | ||||
| MEGDNN_MARK_USED_VAR(pad_right); | |||||
| constexpr int IC_PACK_SIZE = 4; | constexpr int IC_PACK_SIZE = 4; | ||||
| rep_step(ic_idx, ic, IC_PACK_SIZE) { | rep_step(ic_idx, ic, IC_PACK_SIZE) { | ||||
| const int8_t* i_src = src + ic_idx * ic_step; | const int8_t* i_src = src + ic_idx * ic_step; | ||||
| @@ -66,6 +67,7 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, | |||||
| const int ih, const int pad_left, | const int ih, const int pad_left, | ||||
| const int pad_right, const int pad_top, | const int pad_right, const int pad_top, | ||||
| const int pad_bottom) { | const int pad_bottom) { | ||||
| MEGDNN_MARK_USED_VAR(pad_right); | |||||
| constexpr int IC_PACK_SIZE = 4; | constexpr int IC_PACK_SIZE = 4; | ||||
| int odd_start = megdnn::div_ceil(dst_step, 2); | int odd_start = megdnn::div_ceil(dst_step, 2); | ||||
| bool nochange = pad_left % 2 == 0; | bool nochange = pad_left % 2 == 0; | ||||
| @@ -367,4 +369,4 @@ FOR_FILTER(2) | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| //vim: syntax=cpp.doxygen | |||||
| //vim: syntax=cpp.doxygen | |||||
| @@ -163,6 +163,7 @@ static void conv_kern(WorkspaceBundle bundle, | |||||
| bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | ||||
| FallbackConvBiasImpl*, const NCBKernSizeParam& param, | FallbackConvBiasImpl*, const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy algo_selection_strategy) const { | AlgoSelectionStrategy algo_selection_strategy) const { | ||||
| MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| auto FW = fm.spatial[1]; | auto FW = fm.spatial[1]; | ||||
| @@ -199,6 +200,7 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | |||||
| bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | ||||
| megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
| MEGDNN_MARK_USED_VAR(param); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -338,4 +340,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | |||||
| #endif | #endif | ||||
| //vim: syntax=cpp.doxygen | |||||
| //vim: syntax=cpp.doxygen | |||||
| @@ -98,6 +98,7 @@ template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<1, ow_remain, Op, T> { | struct StoreOCxOWx<1, ow_remain, Op, T> { | ||||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | ||||
| const int ld_dst_oc) { | const int ld_dst_oc) { | ||||
| MEGDNN_MARK_USED_VAR(ld_dst_oc); | |||||
| switch (ow_remain) { | switch (ow_remain) { | ||||
| case 8: | case 8: | ||||
| UNROLL_CALL_RAW(4, cb12); | UNROLL_CALL_RAW(4, cb12); | ||||
| @@ -337,14 +337,11 @@ ConvBias::WinogradParam ConvBias::parse_winograd_name( | |||||
| &(ret.channel_block_size), &(ret.output_block_size), | &(ret.channel_block_size), &(ret.output_block_size), | ||||
| &(ret.tile_size)); | &(ret.tile_size)); | ||||
| if (strcmp(name, pre.c_str())) { | if (strcmp(name, pre.c_str())) { | ||||
| megdnn_log_warn("algo %s is not %s algo", name, pre.c_str()); | |||||
| ret = INVALID_WINOGRAD_PARAM; | ret = INVALID_WINOGRAD_PARAM; | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (ret.tile_size == 0 || ret.output_block_size == 0 || | if (ret.tile_size == 0 || ret.output_block_size == 0 || | ||||
| ret.channel_block_size == 0) { | ret.channel_block_size == 0) { | ||||
| megdnn_log_warn("the algo name %s is not suitable for %s", | |||||
| algo_name.c_str(), pre.c_str()); | |||||
| ret = INVALID_WINOGRAD_PARAM; | ret = INVALID_WINOGRAD_PARAM; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||||
| dst[3] = src[3]; | dst[3] = src[3]; | ||||
| dst[4] = 4; | dst[4] = 4; | ||||
| break; | break; | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| dst.ndim = 5; | |||||
| megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
| dst[0] = src[0]; | |||||
| dst[1] = div_ceil(src[1], 4_z); | |||||
| dst[2] = src[2]; | |||||
| dst[3] = src[3]; | |||||
| dst[4] = 4; | |||||
| break; | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); | |||||
| megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); | |||||
| dst.ndim = 5; | |||||
| dst[0] = src[0]; | |||||
| dst[1] = div_ceil(src[1], 4_z); | |||||
| dst[2] = src[2]; | |||||
| dst[3] = src[3]; | |||||
| dst[4] = 4; | |||||
| break; | |||||
| case Param::Mode::NCHW_NCHW88: | case Param::Mode::NCHW_NCHW88: | ||||
| dst.ndim = 5; | dst.ndim = 5; | ||||
| dst[0] = src[0]; | dst[0] = src[0]; | ||||
| @@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
| case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
| dst = src; | dst = src; | ||||
| break; | break; | ||||
| @@ -284,6 +306,15 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
| megdnn_throw("Invalid relayout format mode"); | megdnn_throw("Invalid relayout format mode"); | ||||
| break; | break; | ||||
| } | } | ||||
| if (!dst.is_default() && | |||||
| ( | |||||
| handle()->type() != Handle::HandleType::NAIVE)) { | |||||
| megdnn_throw( | |||||
| "Only naive and opencl handle support " | |||||
| "Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 " | |||||
| "to enable naive handle"); | |||||
| } | |||||
| #undef CHECK_SRC | #undef CHECK_SRC | ||||
| } | } | ||||
| @@ -374,6 +405,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||||
| exec_dst = dst; | exec_dst = dst; | ||||
| } | } | ||||
| break; | break; | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: | |||||
| // nchw to nchw4c or oihw to oihw4i | |||||
| { | |||||
| TensorLayout work_space_layout( | |||||
| {src[0], round_up(src[1], 4_z), src[2], src[3]}, | |||||
| src.dtype, src.format); | |||||
| exec_src = work_space_layout | |||||
| .reshape({src[0], div_ceil(src[1], 4_z), 4, | |||||
| src[2], src[3]}) | |||||
| .dimshuffle({0, 1, 3, 4, 2}); | |||||
| exec_dst = dst; | |||||
| } | |||||
| break; | |||||
| case Param::Mode::NCHW_NHWCD4: | case Param::Mode::NCHW_NHWCD4: | ||||
| case Param::Mode::NCHW_NHWCD4I: | case Param::Mode::NCHW_NHWCD4I: | ||||
| // src is {N, C, H, W} | // src is {N, C, H, W} | ||||
| @@ -10,6 +10,7 @@ | |||||
| */ | */ | ||||
| #include "src/cuda/convolution/opr_impl.h" | #include "src/cuda/convolution/opr_impl.h" | ||||
| #include "megdnn/dtype.h" | |||||
| #include "src/cuda/convolution/helper.h" | #include "src/cuda/convolution/helper.h" | ||||
| #include "src/cuda/convolution/backward_data/algo.h" | #include "src/cuda/convolution/backward_data/algo.h" | ||||
| #include "src/cuda/convolution/backward_filter/algo.h" | #include "src/cuda/convolution/backward_filter/algo.h" | ||||
| @@ -28,10 +29,35 @@ using namespace convolution; | |||||
| /* ============== ConvolutionForwardImpl ============== */ | /* ============== ConvolutionForwardImpl ============== */ | ||||
| ConvolutionForwardImpl::ConvBiasExtraData | ConvolutionForwardImpl::ConvBiasExtraData | ||||
| ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& dst) { | |||||
| ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| auto conv_param = param(); | auto conv_param = param(); | ||||
| DType bias_type; | |||||
| if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| bias_type = dtype::QuantizedS32( | |||||
| src.dtype.param<dtype::QuantizedS8>().scale * | |||||
| filter.dtype.param<dtype::QuantizedS8>().scale); | |||||
| } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| bias_type = dtype::QuantizedS32( | |||||
| src.dtype.param<dtype::Quantized8Asymm>().scale * | |||||
| filter.dtype.param<dtype::Quantized8Asymm>().scale); | |||||
| } else if (src.dtype.enumv() == DTypeEnum::Uint8 || | |||||
| src.dtype.enumv() == DTypeEnum::Int8) { | |||||
| bias_type = dtype::Int32{}; | |||||
| } else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
| bias_type = dtype::QuantizedS32( | |||||
| src.dtype.param<dtype::Quantized4Asymm>().scale * | |||||
| filter.dtype.param<dtype::Quantized4Asymm>().scale); | |||||
| } else { | |||||
| megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | |||||
| bias_type = src.dtype; | |||||
| } | |||||
| ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(), | ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(), | ||||
| TensorLayout(dst.dtype), TensorLayout(dst.dtype)}; | |||||
| TensorLayout(bias_type), TensorLayout(dst.dtype)}; | |||||
| ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | ||||
| conv_param.mode, | conv_param.mode, | ||||
| conv_param.sparse, | conv_param.sparse, | ||||
| @@ -54,7 +80,7 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| bool reproducible) { | bool reproducible) { | ||||
| auto extra_data = conv_bias_extra_data(dst); | |||||
| auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
| return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
| ->get_algorithm_heuristic(src, filter, extra_data.bias_layout, | ->get_algorithm_heuristic(src, filter, extra_data.bias_layout, | ||||
| extra_data.z_layout, dst, | extra_data.z_layout, dst, | ||||
| @@ -65,7 +91,7 @@ std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
| ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| auto extra_data = conv_bias_extra_data(dst); | |||||
| auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
| return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
| ->get_all_algorithms(src, filter, extra_data.bias_layout, | ->get_all_algorithms(src, filter, extra_data.bias_layout, | ||||
| extra_data.z_layout, dst); | extra_data.z_layout, dst); | ||||
| @@ -75,7 +101,7 @@ size_t ConvolutionForwardImpl::get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| const PreprocessedFilter* preprocessed_filter) { | const PreprocessedFilter* preprocessed_filter) { | ||||
| auto extra_data = conv_bias_extra_data(dst); | |||||
| auto extra_data = conv_bias_extra_data(src, filter, dst); | |||||
| return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
| ->get_workspace_in_bytes( | ->get_workspace_in_bytes( | ||||
| src, filter, extra_data.bias_layout, extra_data.z_layout, | src, filter, extra_data.bias_layout, extra_data.z_layout, | ||||
| @@ -90,7 +116,8 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
| const PreprocessedFilter* preprocessed_filter, | const PreprocessedFilter* preprocessed_filter, | ||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| auto extra_data = conv_bias_extra_data(dst.layout); | |||||
| auto extra_data = | |||||
| conv_bias_extra_data(src.layout, filter.layout, dst.layout); | |||||
| TensorND bias(nullptr, extra_data.bias_layout); | TensorND bias(nullptr, extra_data.bias_layout); | ||||
| TensorND z(nullptr, extra_data.z_layout); | TensorND z(nullptr, extra_data.z_layout); | ||||
| return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get()) | ||||
| @@ -61,7 +61,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
| TensorLayout z_layout; | TensorLayout z_layout; | ||||
| }; | }; | ||||
| private: | private: | ||||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&); | |||||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&); | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | ||||
| @@ -32,7 +32,7 @@ void create_param(const DeformablePSROIPoolingBase* opr, | |||||
| p.sample_per_part = param.sample_per_part; | p.sample_per_part = param.sample_per_part; | ||||
| p.trans_std = param.trans_std; | p.trans_std = param.trans_std; | ||||
| p.scale = param.spatial_scale; | p.scale = param.spatial_scale; | ||||
| p.nr_cls = p.no_trans ? 1 : trans[0]; | |||||
| p.nr_cls = p.no_trans ? 1 : trans[1] / 2; | |||||
| p.nr_bbox = rois[0]; | p.nr_bbox = rois[0]; | ||||
| p.IC = data[1]; | p.IC = data[1]; | ||||
| p.IH = data[2]; | p.IH = data[2]; | ||||
| @@ -11,6 +11,7 @@ | |||||
| #include "src/cuda/relayout_format/opr_impl.h" | #include "src/cuda/relayout_format/opr_impl.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| auto src_dtype = src.layout.dtype; | auto src_dtype = src.layout.dtype; | ||||
| megdnn_assert( | megdnn_assert( | ||||
| param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | ||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4, | |||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || | |||||
| param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
| param().mode == | |||||
| Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, | |||||
| "relayout format of cuda only support NCHW4->CHWN4 or " | "relayout format of cuda only support NCHW4->CHWN4 or " | ||||
| "CHWN4->NCHW4"); | |||||
| if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| "CHWN4->NCHW4 or NCHW->NCHW4"); | |||||
| if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | |||||
| param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) && | |||||
| src_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| size_t row = 0, col = 0; | size_t row = 0, col = 0; | ||||
| if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { | ||||
| row = src.layout[0], | row = src.layout[0], | ||||
| col = src.layout[1] * src.layout[2] * src.layout[3]; | col = src.layout[1] * src.layout[2] * src.layout[3]; | ||||
| } else { | } else { | ||||
| megdnn_assert(param().mode == | |||||
| param::RelayoutFormat::Mode::CHWN4_NCHW4); | |||||
| row = src.layout[0] * src.layout[1] * src.layout[2], | row = src.layout[0] * src.layout[1] * src.layout[2], | ||||
| col = src.layout[3]; | col = src.layout[3]; | ||||
| } | } | ||||
| @@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| return handle()->create_operator<RelayoutForward>()->exec(trans_in, | return handle()->create_operator<RelayoutForward>()->exec(trans_in, | ||||
| trans_out); | trans_out); | ||||
| } | } | ||||
| if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || | |||||
| param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) && | |||||
| src.layout[1] % 4 != 0) { | |||||
| megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4, | |||||
| "The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " | |||||
| "of RelayoutFormat opr(cuda backend) does not support " | |||||
| "src.ptr == dst.ptr"); | |||||
| megdnn_assert(src.layout[1] <= 4); | |||||
| cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, | |||||
| dst.layout.span().dist_byte(), | |||||
| cuda_stream(this->handle()))); | |||||
| TensorLayout exec_dst_layout = dst.layout; | |||||
| exec_dst_layout[4] = src.layout[1]; | |||||
| TensorLayout exec_src_layout = | |||||
| src.layout | |||||
| .reshape({src.layout[0], src.layout[1], 1, | |||||
| src.layout[2], src.layout[3]}) | |||||
| .dimshuffle({0, 2, 3, 4, 1}); | |||||
| return handle()->create_operator<RelayoutForward>()->exec( | |||||
| {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||||
| } | |||||
| TensorLayout exec_src, exec_dst; | TensorLayout exec_src, exec_dst; | ||||
| deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); | ||||
| TensorND exec_src_nd{src.raw_ptr, exec_src}; | TensorND exec_src_nd{src.raw_ptr, exec_src}; | ||||
| @@ -293,7 +293,7 @@ void Fwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, | |||||
| float trans_std = param.trans_std, scale = param.spatial_scale; | float trans_std = param.trans_std, scale = param.spatial_scale; | ||||
| size_t nr_bbox = rois.layout[0]; | size_t nr_bbox = rois.layout[0]; | ||||
| size_t nr_cls = no_trans ? 1 : trans.layout[0]; | |||||
| size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; | |||||
| size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | ||||
| const float* data_ptr = data.ptr<float>(); | const float* data_ptr = data.ptr<float>(); | ||||
| @@ -339,7 +339,7 @@ void Bwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, | |||||
| float trans_std = param.trans_std, scale = param.spatial_scale; | float trans_std = param.trans_std, scale = param.spatial_scale; | ||||
| size_t nr_bbox = rois.layout[0]; | size_t nr_bbox = rois.layout[0]; | ||||
| size_t nr_cls = no_trans ? 1 : trans.layout[0]; | |||||
| size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; | |||||
| size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; | ||||
| const float* data_ptr = data.ptr<float>(); | const float* data_ptr = data.ptr<float>(); | ||||
| @@ -107,11 +107,7 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | |||||
| m_dispatcher{megcoreGetCPUDispatcher(computing_handle)} {} | m_dispatcher{megcoreGetCPUDispatcher(computing_handle)} {} | ||||
| size_t HandleImpl::image2d_pitch_alignment() const { | size_t HandleImpl::image2d_pitch_alignment() const { | ||||
| if (type() == Handle::HandleType::NAIVE) { | |||||
| // only naive CPU handle supports this format | |||||
| return g_image2d_pitch_alignment; | |||||
| } | |||||
| megdnn_throw("Image2DTensorFormat is not supported on this handle"); | |||||
| return g_image2d_pitch_alignment; | |||||
| } | } | ||||
| size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) { | size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) { | ||||
| @@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, | |||||
| } | } | ||||
| } | } | ||||
| } // anonymous namespace | |||||
| } // namespace | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace naive { | namespace naive { | ||||
| void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| MIDOUT_BEGIN(megdnn_naive_pooling) { | |||||
| check_exec(src.layout, dst.layout, workspace.size); | |||||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||||
| if (param().format == Param::Format::NCHW || | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW32) { | |||||
| c_pos = 1; | |||||
| spatial_pos = 2; | |||||
| } else if (param().format == Param::Format::NHWC) { | |||||
| c_pos = 3; | |||||
| spatial_pos = 1; | |||||
| } else if (param().format == Param::Format::CHWN4) { | |||||
| c_pos = 0; | |||||
| spatial_pos = 1; | |||||
| batch_pos = 3; | |||||
| } else { | |||||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
| c_pos = 2; | |||||
| spatial_pos = 1; | |||||
| } | |||||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
| IH = src.layout.shape[spatial_pos + 0], | |||||
| IW = src.layout.shape[spatial_pos + 1]; | |||||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
| OW = dst.layout.shape[spatial_pos + 1]; | |||||
| if (param().format == Param::Format::NHWCD4) { | |||||
| C *= 4; | |||||
| IW = src.layout.shape[spatial_pos + 2]; | |||||
| OW = dst.layout.shape[spatial_pos + 2]; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | |||||
| C *= 4; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW88) { | |||||
| C *= 8; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW32) { | |||||
| C *= 32; | |||||
| } | |||||
| size_t PH = param().pad_h, PW = param().pad_w; | |||||
| size_t FH = param().window_h, FW = param().window_w; | |||||
| size_t SH = param().stride_h, SW = param().stride_w; | |||||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(handle()), \ | |||||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ | |||||
| PW, SH, SW, FH, FW)); | |||||
| check_exec(src.layout, dst.layout, workspace.size); | |||||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||||
| if (param().format == Param::Format::NCHW || | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW32) { | |||||
| c_pos = 1; | |||||
| spatial_pos = 2; | |||||
| } else if (param().format == Param::Format::NHWC) { | |||||
| c_pos = 3; | |||||
| spatial_pos = 1; | |||||
| } else if (param().format == Param::Format::CHWN4) { | |||||
| c_pos = 0; | |||||
| spatial_pos = 1; | |||||
| batch_pos = 3; | |||||
| } else { | |||||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
| c_pos = 2; | |||||
| spatial_pos = 1; | |||||
| } | |||||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
| IH = src.layout.shape[spatial_pos + 0], | |||||
| IW = src.layout.shape[spatial_pos + 1]; | |||||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
| OW = dst.layout.shape[spatial_pos + 1]; | |||||
| if (param().format == Param::Format::NHWCD4) { | |||||
| C *= 4; | |||||
| IW = src.layout.shape[spatial_pos + 2]; | |||||
| OW = dst.layout.shape[spatial_pos + 2]; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | |||||
| C *= 4; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW88) { | |||||
| C *= 8; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW32) { | |||||
| C *= 32; | |||||
| } | |||||
| size_t PH = param().pad_h, PW = param().pad_w; | |||||
| size_t FH = param().window_h, FW = param().window_w; | |||||
| size_t SH = param().stride_h, SW = param().stride_w; | |||||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
| MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(handle()), \ | |||||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ | |||||
| PH, PW, SH, SW, FH, FW)); \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| #define DISPATCH_WITH_POOLER(Pooler) \ | #define DISPATCH_WITH_POOLER(Pooler) \ | ||||
| switch (param().format) { \ | switch (param().format) { \ | ||||
| @@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| } \ | } \ | ||||
| } \ | } \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
| #undef cb | #undef cb | ||||
| #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | ||||
| #undef DISPATCH_WITH_POOLER | #undef DISPATCH_WITH_POOLER | ||||
| megdnn_assert_internal(0); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| megdnn_assert_internal(0); | |||||
| } | } | ||||
| WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | ||||
| @@ -14,6 +14,10 @@ | |||||
| #include "megdnn/tensor_iter.h" | #include "megdnn/tensor_iter.h" | ||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megdnn_naive_relayout_format) | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace naive; | using namespace naive; | ||||
| @@ -79,6 +83,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, | |||||
| } | } | ||||
| cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
| cb(QuantizedS8, dt_qint8); | |||||
| default: | default: | ||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| #undef cb | #undef cb | ||||
| @@ -138,7 +143,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return n * c * h * w * src.dtype.size(); | return n * c * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
| megdnn_assert(src[0] % 8 == 0, | megdnn_assert(src[0] % 8 == 0, | ||||
| "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); | ||||
| if (src[1] % 8 == 0) | if (src[1] % 8 == 0) | ||||
| @@ -150,7 +155,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return oc * ic * h * w * src.dtype.size(); | return oc * ic * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
| megdnn_assert(src[1] % 8 == 0, | megdnn_assert(src[1] % 8 == 0, | ||||
| "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " | ||||
| "align to 8"); | "align to 8"); | ||||
| @@ -164,7 +169,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { | ||||
| megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); | |||||
| megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); | |||||
| if (src[0] % 8 == 0) | if (src[0] % 8 == 0) | ||||
| return 0; | return 0; | ||||
| size_t group = round_up(src[0], 8_z); | size_t group = round_up(src[0], 8_z); | ||||
| @@ -174,6 +179,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| size_t w = src[4]; | size_t w = src[4]; | ||||
| return group * ocpg * icpg * h * w * src.dtype.size(); | return group * ocpg * icpg * h * w * src.dtype.size(); | ||||
| } | } | ||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL: { | |||||
| if (src[1] % 4 == 0) | |||||
| return 0; | |||||
| size_t n = src[0]; | |||||
| size_t c = round_up(src[1], 4_z); | |||||
| size_t h = src[2]; | |||||
| size_t w = src[3]; | |||||
| return n * c * h * w * src.dtype.size(); | |||||
| } | |||||
| case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { | |||||
| megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); | |||||
| if (src[1] % 4 == 0) | |||||
| return 0; | |||||
| size_t oc = src[0]; | |||||
| size_t ic = round_up(src[1], 4_z); | |||||
| size_t h = src[2]; | |||||
| size_t w = src[3]; | |||||
| return oc * ic * h * w * src.dtype.size(); | |||||
| } | |||||
| default: | default: | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -200,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| //! ic % 4 != 0 | //! ic % 4 != 0 | ||||
| if ((IC & 0x3)) { | if ((IC & 0x3)) { | ||||
| switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
| #define cb(name, ctype) \ | |||||
| case (DTypeEnum::name): { \ | |||||
| ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
| ctype* dptr = workspace.ptr<ctype>(); \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| m_handle, \ | |||||
| padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \ | |||||
| break; \ | |||||
| #define cb(name, ctype) \ | |||||
| case (DTypeEnum::name): { \ | |||||
| MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ | |||||
| midout_iv(Param::Mode::NCHW_NHWCD4I)) { \ | |||||
| ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
| ctype* dptr = workspace.ptr<ctype>(); \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \ | |||||
| IC, IH, IW);); \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| break; \ | |||||
| } | } | ||||
| cb(Float32, dt_float32); | cb(Float32, dt_float32); | ||||
| MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); | MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); | ||||
| @@ -226,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| size_t FW = src.layout[3]; | size_t FW = src.layout[3]; | ||||
| if ((IC & 0x3)) { | if ((IC & 0x3)) { | ||||
| switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
| #define cb(name, ctype) \ | |||||
| case (DTypeEnum::name): { \ | |||||
| ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
| ctype* dptr = workspace.ptr<ctype>(); \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \ | |||||
| IC, FH, FW);); \ | |||||
| break; \ | |||||
| #define cb(name, ctype) \ | |||||
| case (DTypeEnum::name): { \ | |||||
| MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ | |||||
| midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \ | |||||
| ctype* sptr = src.compatible_ptr<ctype>(); \ | |||||
| ctype* dptr = workspace.ptr<ctype>(); \ | |||||
| MEGDNN_DISPATCH_CPU_KERN(m_handle, \ | |||||
| padding_filter_to_workspace<ctype>( \ | |||||
| dptr, sptr, OC, IC, FH, FW);); \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| break; \ | |||||
| } | } | ||||
| cb(Quantized8Asymm, dt_uint8); | cb(Quantized8Asymm, dt_uint8); | ||||
| cb(QuantizedS8, dt_int8); | cb(QuantizedS8, dt_int8); | ||||
| @@ -244,33 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | exec_src_nd.raw_ptr = workspace.raw_ptr; | ||||
| } | } | ||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88) { | } else if (param().mode == Param::Mode::NCHW_NCHW88) { | ||||
| size_t ic = src.layout[1]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| #define cb(_idx, _pack_size, _mode) \ | |||||
| MIDOUT_BEGIN(megdnn_naive_relayout_format, \ | |||||
| midout_iv(Param::Mode::_mode)) { \ | |||||
| size_t val = src.layout[_idx]; \ | |||||
| if (val % _pack_size != 0) { \ | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ | |||||
| _pack_size); \ | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| cb(1, 8, NCHW_NCHW88); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { | ||||
| megdnn_assert(src.layout[0] % 8 == 0); | megdnn_assert(src.layout[0] % 8 == 0); | ||||
| size_t ic = src.layout[1]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { | ||||
| size_t group = src.layout[0]; | |||||
| if (group % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { | ||||
| megdnn_assert(src.layout[1] % 8 == 0); | megdnn_assert(src.layout[1] % 8 == 0); | ||||
| size_t ic = src.layout[2]; | |||||
| if (ic % 8 != 0) { | |||||
| padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z); | |||||
| exec_src_nd.raw_ptr = workspace.raw_ptr; | |||||
| } | |||||
| cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT); | |||||
| } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { | |||||
| cb(1, 4, NCHW_NCHW4_IC_SMALL); | |||||
| } else if (param().mode == | |||||
| Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { | |||||
| cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); | |||||
| } | } | ||||
| m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); | ||||
| #undef cb | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -8,6 +8,7 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "megdnn/dtype.h" | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "test/common/checker.h" | #include "test/common/checker.h" | ||||
| #include "test/common/rng.h" | #include "test/common/rng.h" | ||||
| @@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { | |||||
| checker.execs({{22, 23, 24, 25, 4}, {}}); | checker.execs({{22, 23, 24, 25, 4}, {}}); | ||||
| } | } | ||||
| TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { | |||||
| Checker<RelayoutFormat> checker(handle_cuda()); | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| param::RelayoutFormat param; | |||||
| param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; | |||||
| for (DType dtype : | |||||
| std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) { | |||||
| checker.set_dtype(0, dtype).set_rng(0, &rng); | |||||
| checker.set_param(param).execs({{2, 4, 35, 36}, {}}); | |||||
| checker.set_param(param).execs({{2, 3, 35, 36}, {}}); | |||||
| checker.set_param(param).execs({{2, 1, 35, 36}, {}}); | |||||
| param.mode = param::RelayoutFormat::Mode:: | |||||
| NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT; | |||||
| checker.set_param(param).execs({{4, 3, 3, 3}, {}}); | |||||
| checker.set_param(param).execs({{4, 4, 3, 3}, {}}); | |||||
| checker.set_param(param).execs({{1, 4, 3, 3}, {}}); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -25,7 +25,7 @@ from . import config, craniotome, dtype | |||||
| from . import global_init as _global_init | from . import global_init as _global_init | ||||
| from . import helper as _helper | from . import helper as _helper | ||||
| from . import mgb as _detail | from . import mgb as _detail | ||||
| from . import opr, opr_param_defs, plugin | |||||
| from . import opr, opr_extra, opr_param_defs, plugin | |||||
| from .exc import MegBrainError | from .exc import MegBrainError | ||||
| from .logconf import get_logger | from .logconf import get_logger | ||||
| from .mgb import ( | from .mgb import ( | ||||
| @@ -0,0 +1,3 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # Copyright (c) 2015-2019 Megvii Inc. All rights reserved. | |||||
| @@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta): | |||||
| memo[id(self)] = result | memo[id(self)] = result | ||||
| for k, v in self.__dict__.items(): | for k, v in self.__dict__.items(): | ||||
| setattr(result, k, copy.deepcopy(v, memo)) | setattr(result, k, copy.deepcopy(v, memo)) | ||||
| setattr(result, "saved_tensors", tmp) | |||||
| self.saved_tensors = tmp | self.saved_tensors = tmp | ||||
| return result | return result | ||||
| @@ -235,6 +235,14 @@ class Tensor: | |||||
| return self.__val.dtype | return self.__val.dtype | ||||
| return self._symvar.dtype | return self._symvar.dtype | ||||
| def set_dtype(self, dtype: str = None): | |||||
| r"""Set the data type of the tensor. | |||||
| """ | |||||
| if self.__val is not None: | |||||
| self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) | |||||
| elif self.__sym is not None: | |||||
| self.__sym = self.__sym.astype(dtype) | |||||
| @property | @property | ||||
| def _comp_node(self): | def _comp_node(self): | ||||
| if self.__val is not None: | if self.__val is not None: | ||||
| @@ -26,7 +26,7 @@ def _clear_plasma_store(): | |||||
| # `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, | # `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, | ||||
| # so this function should be called explicitly | # so this function should be called explicitly | ||||
| global MGE_PLASMA_STORE_MANAGER | global MGE_PLASMA_STORE_MANAGER | ||||
| if MGE_PLASMA_STORE_MANAGER is not None: | |||||
| if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0: | |||||
| del MGE_PLASMA_STORE_MANAGER | del MGE_PLASMA_STORE_MANAGER | ||||
| MGE_PLASMA_STORE_MANAGER = None | MGE_PLASMA_STORE_MANAGER = None | ||||
| @@ -50,6 +50,7 @@ class _PlasmaStoreManager: | |||||
| stderr=None if debug_flag else subprocess.DEVNULL, | stderr=None if debug_flag else subprocess.DEVNULL, | ||||
| ) | ) | ||||
| self.__initialized = True | self.__initialized = True | ||||
| self.refcount = 1 | |||||
| def __del__(self): | def __del__(self): | ||||
| if self.__initialized and self.plasma_store.returncode is None: | if self.__initialized and self.plasma_store.returncode is None: | ||||
| @@ -83,6 +84,8 @@ class PlasmaShmQueue: | |||||
| "Exception happened in starting plasma_store: {}\n" | "Exception happened in starting plasma_store: {}\n" | ||||
| "Tips: {}".format(str(e), err_info) | "Tips: {}".format(str(e), err_info) | ||||
| ) | ) | ||||
| else: | |||||
| MGE_PLASMA_STORE_MANAGER.refcount += 1 | |||||
| self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | ||||
| @@ -133,6 +136,8 @@ class PlasmaShmQueue: | |||||
| def close(self): | def close(self): | ||||
| self.queue.close() | self.queue.close() | ||||
| self.disconnect_client() | self.disconnect_client() | ||||
| global MGE_PLASMA_STORE_MANAGER | |||||
| MGE_PLASMA_STORE_MANAGER.refcount -= 1 | |||||
| _clear_plasma_store() | _clear_plasma_store() | ||||
| def cancel_join_thread(self): | def cancel_join_thread(self): | ||||
| @@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor | |||||
| ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) | ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) | ||||
| ret = ret.reshape(orig_shape[:-1], weight.shape[0]) | ret = ret.reshape(orig_shape[:-1], weight.shape[0]) | ||||
| if bias is not None: | if bias is not None: | ||||
| ret += bias | |||||
| ret += bias.reshape(1, bias.shape[0]) | |||||
| return ret | return ret | ||||
| @@ -442,17 +442,38 @@ class trace: | |||||
| Serialize trace to file system. | Serialize trace to file system. | ||||
| :param fpath: positional only argument. Path of output file. | :param fpath: positional only argument. Path of output file. | ||||
| :param arg_names: names of the input tensors in the traced function | |||||
| :param append: whether output is appended to ``fpath`` | |||||
| :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use | |||||
| :param arg_names: names of the input tensors in the traced function. | |||||
| :param append: whether output is appended to ``fpath``. | |||||
| :param optimize_for_inference: whether to enable optimize_for_inference | |||||
| pass before dump. | |||||
| :param enable_io16xc32: whether to use float16 for I/O between oprs and use | |||||
| float32 as internal computation precision. Note the output var would be | float32 as internal computation precision. Note the output var would be | ||||
| changed to float16 | |||||
| :param f16_io_comp: whether to use float16 for both I/O and computation | |||||
| precision | |||||
| :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some | |||||
| OpenCL devices | |||||
| :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
| into one opr. This is supported only in NHWCD4 format. | |||||
| changed to float16. | |||||
| :param enable_ioc16: whether to use float16 for both I/O and computation | |||||
| precision. | |||||
| :param enable_hwcd4: whether to use NHWCD4 data layout. This is faster on some | |||||
| OpenCL backend. | |||||
| :param enable_nchw88: whether to use NCHW4 data layout. it currently | |||||
| used in X86 AVX backend. | |||||
| :param enable_nchw44: whether to use NCHW4 data layout. it currently | |||||
| used in arm backend. | |||||
| :param enable_nchw44_dot: whether to use NCHW4 data layout. it currently | |||||
| used in armv8.2+dotprod backend. | |||||
| :param enable_nchw4: whether to use NCHW4 data layout. it currently | |||||
| used in nvidia backend(based on cudnn). | |||||
| :param enable_nchw32 whether to use NCHW32 data layout. it currently | |||||
| used in nvidia backend with tensorcore(based on cudnn). | |||||
| :param enable_chwn4 whether to use CHWN4 data layout. it currently | |||||
| used in nvidia backend with tensorcore. | |||||
| :param enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||||
| into one opr. | |||||
| :param enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||||
| input for inference on nvidia backend(this optimization pass will | |||||
| result in mismatch of the precision of output of training and | |||||
| inference) | |||||
| """ | """ | ||||
| if self._status != self._FINISHED: | if self._status != self._FINISHED: | ||||
| raise ValueError("not traced") | raise ValueError("not traced") | ||||
| @@ -475,6 +496,7 @@ class trace: | |||||
| "enable_nchw88": "use_nchw88", | "enable_nchw88": "use_nchw88", | ||||
| "enable_nchw32": "use_nchw32", | "enable_nchw32": "use_nchw32", | ||||
| "enable_nchw44": "use_nchw44", | "enable_nchw44": "use_nchw44", | ||||
| "enable_nchw44_dot": "use_nchw44_dot", | |||||
| "enable_chwn4": "use_chwn4", | "enable_chwn4": "use_chwn4", | ||||
| "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | ||||
| "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | ||||
| @@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._internal.dtype import is_quantize | |||||
| from ..core import Buffer, Parameter, Tensor | from ..core import Buffer, Parameter, Tensor | ||||
| from ..logger import get_logger | from ..logger import get_logger | ||||
| @@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): | |||||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | ), "param `{}` shape mismatch, should be {}, get {}".format( | ||||
| k, var.shape, to_be_load.shape | k, var.shape, to_be_load.shape | ||||
| ) | ) | ||||
| # For quantized dtype, the initialized dtype | |||||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||||
| var.set_dtype(to_be_load.dtype) | |||||
| var.set_value(to_be_load) | var.set_value(to_be_load) | ||||
| loaded.append(k) | loaded.append(k) | ||||
| @@ -37,15 +37,14 @@ class QATModule(Module): | |||||
| Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
| observer and fake_quant for weight and activation. | observer and fake_quant for weight and activation. | ||||
| """ | """ | ||||
| self.weight_observer = qconfig.weight_observer() | |||||
| self.act_observer = qconfig.act_observer() | |||||
| if qconfig.fake_quant is None: | |||||
| self.weight_fake_quant = None | |||||
| self.act_fake_quant = None | |||||
| else: | |||||
| self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||||
| self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||||
| def safe_call(func): | |||||
| return func() if func is not None else None | |||||
| self.weight_observer = safe_call(qconfig.weight_observer) | |||||
| self.act_observer = safe_call(qconfig.act_observer) | |||||
| self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||||
| self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||||
| def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
| self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
| @@ -77,13 +76,19 @@ class QATModule(Module): | |||||
| r""" | r""" | ||||
| Get weight's quantization dtype as the method from ``qconfig``. | Get weight's quantization dtype as the method from ``qconfig``. | ||||
| """ | """ | ||||
| return self.weight_observer.get_dtype() | |||||
| if hasattr(self.act_fake_quant, "get_dtype"): | |||||
| return self.weight_fake_quant.get_dtype() | |||||
| else: | |||||
| return self.weight_observer.get_dtype() | |||||
| def get_activation_dtype(self): | def get_activation_dtype(self): | ||||
| r""" | r""" | ||||
| Get activation's quantization dtype as the method from ``qconfig``. | Get activation's quantization dtype as the method from ``qconfig``. | ||||
| """ | """ | ||||
| return self.act_observer.get_dtype() | |||||
| if hasattr(self.act_fake_quant, "get_dtype"): | |||||
| return self.act_fake_quant.get_dtype() | |||||
| else: | |||||
| return self.act_observer.get_dtype() | |||||
| @classmethod | @classmethod | ||||
| @abstractmethod | @abstractmethod | ||||
| @@ -12,4 +12,5 @@ from .qconfig import ( | |||||
| calibration_qconfig, | calibration_qconfig, | ||||
| ema_fakequant_qconfig, | ema_fakequant_qconfig, | ||||
| min_max_fakequant_qconfig, | min_max_fakequant_qconfig, | ||||
| tqt_quant_qconfig, | |||||
| ) | ) | ||||
| @@ -5,18 +5,21 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import copy | |||||
| import math | |||||
| import numpy as np | |||||
| from .. import functional as F | from .. import functional as F | ||||
| from .._internal.dtype import _metadata_dict | |||||
| from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
| from ..core import Buffer, Function, Parameter | |||||
| from ..jit import sideeffect | |||||
| from ..module import Module | from ..module import Module | ||||
| from .observer import ObserverMode, Round | from .observer import ObserverMode, Round | ||||
| class FakeQuantize(Module): | |||||
| r""" | |||||
| A module to do quant and dequant according to observer's scale and zero_point. | |||||
| """ | |||||
| def __init__(self, dtype: str, enable: bool = True): | |||||
| class _FakeQuantize(Module): | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| super().__init__() | super().__init__() | ||||
| if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -25,7 +28,10 @@ class FakeQuantize(Module): | |||||
| ) | ) | ||||
| ) | ) | ||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.qmin = _metadata_dict[dtype].qmin | |||||
| self.narrow_range = narrow_range | |||||
| self.qmin = ( | |||||
| -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
| ) | |||||
| self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
| self.enabled = enable | self.enabled = enable | ||||
| @@ -35,25 +41,108 @@ class FakeQuantize(Module): | |||||
| def disable(self): | def disable(self): | ||||
| self.enabled = False | self.enabled = False | ||||
| def fake_quant_forward(self, inp, q_dict): | |||||
| return inp | |||||
| def normal_foward(self, inp, q_dict): | |||||
| return inp | |||||
| def forward(self, inp, q_dict): | def forward(self, inp, q_dict): | ||||
| if self.enabled: | if self.enabled: | ||||
| if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
| scale = q_dict["scale"] | |||||
| # Quant | |||||
| oup = Round()(inp / scale) | |||||
| # clip | |||||
| oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
| # DeQuant | |||||
| oup = (oup) * scale | |||||
| return oup | |||||
| else: | |||||
| scale = q_dict["scale"] | |||||
| zero_point = q_dict["zero_point"] | |||||
| # Quant | |||||
| oup = Round()(inp / scale) + zero_point | |||||
| # clip | |||||
| oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
| # DeQuant | |||||
| oup = (oup - zero_point) * scale | |||||
| return oup | |||||
| return self.fake_quant_forward(inp, q_dict) | |||||
| else: | |||||
| return self.normal_foward(inp, q_dict) | |||||
| class TQT_Function(Function): | |||||
| def __init__(self, lowerbound, upperbound): | |||||
| super().__init__() | |||||
| self.lowerbound = lowerbound | |||||
| self.upperbound = upperbound | |||||
| def forward(self, inp, scale): | |||||
| t = 2 ** scale | |||||
| # t = F.maximum(t, 1e-4) | |||||
| inp_scaled = inp / t | |||||
| inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound) | |||||
| inp_rounded = F.round(inp_clipped) | |||||
| inp_flq = inp_rounded * t | |||||
| self.save_for_backward(inp_scaled, inp_rounded, t) | |||||
| return inp_flq | |||||
| def backward(self, grad_inp_flq): | |||||
| (inp_scaled, inp_rounded, t) = self.saved_tensors | |||||
| mask_clip = (inp_scaled < -0.5 + self.lowerbound) + ( | |||||
| inp_scaled > self.upperbound + 0.5 | |||||
| ) # mask for accumulating the gradients of |data_scaled|>L | |||||
| mask_quant = F.abs( | |||||
| mask_clip - 1 | |||||
| ) # mask for accumulating the gradients with |data_scaled|<=L | |||||
| grad_quant = ( | |||||
| grad_inp_flq * mask_quant * (inp_rounded - inp_scaled) | |||||
| ) # gradient within |data_scaled|<=L | |||||
| grad_clip = ( | |||||
| grad_inp_flq * mask_clip * inp_rounded | |||||
| ) # gradient with | data_scaled|>L | |||||
| grad_s = grad_clip.sum() + grad_quant.sum() | |||||
| # dL/ds = dL/dt * t * ln(2) | |||||
| grad_s = grad_s * t * math.log(2) | |||||
| grad_inp = grad_inp_flq * mask_quant | |||||
| return grad_inp, grad_s | |||||
| class TQT(_FakeQuantize): | |||||
| """ | |||||
| TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||||
| for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | |||||
| """ | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| super().__init__(dtype, narrow_range, enable) | |||||
| self.scale = Parameter(0.0, dtype=np.float32) | |||||
| def fake_quant_forward(self, inp, q_dict): | |||||
| # when enabled, TQT does the fake-quant forward pass, finetuning the scale | |||||
| return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | |||||
| def normal_foward(self, inp, q_dict): | |||||
| # when disabled, TQT does a normal forward pass and initializes the scale parameter | |||||
| tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||||
| tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||||
| F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||||
| return inp | return inp | ||||
| def get_dtype(self): | |||||
| return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None) | |||||
| class FakeQuantize(_FakeQuantize): | |||||
| r""" | |||||
| A module to do quant and dequant according to observer's scale and zero_point. | |||||
| :param dtype: A string indicating the target quantization type of input. | |||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
| """ | |||||
| def fake_quant_forward(self, inp, q_dict): | |||||
| if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
| scale = q_dict["scale"] | |||||
| # Quant | |||||
| oup = Round()(inp / scale) | |||||
| # clip | |||||
| oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
| # DeQuant | |||||
| oup = (oup) * scale | |||||
| return oup | |||||
| else: | |||||
| scale = q_dict["scale"] | |||||
| zero_point = q_dict["zero_point"] | |||||
| # Quant | |||||
| oup = Round()(inp / scale) + zero_point | |||||
| # clip | |||||
| oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||||
| # DeQuant | |||||
| oup = (oup - zero_point) * scale | |||||
| return oup | |||||
| @@ -31,9 +31,11 @@ class Observer(Module): | |||||
| A base class for Observer Module. | A base class for Observer Module. | ||||
| :param dtype: a string indicating to collect scale and zero_point of which dtype | :param dtype: a string indicating to collect scale and zero_point of which dtype | ||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| """ | """ | ||||
| def __init__(self, dtype="qint8"): | |||||
| def __init__(self, dtype: str, narrow_range: bool = False): | |||||
| super().__init__() | super().__init__() | ||||
| if dtype not in _metadata_dict.keys(): | if dtype not in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -42,7 +44,10 @@ class Observer(Module): | |||||
| ) | ) | ||||
| ) | ) | ||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.qmin = _metadata_dict[dtype].qmin | |||||
| self.narrow_range = narrow_range | |||||
| self.qmin = ( | |||||
| -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
| ) | |||||
| self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
| self.enabled = True | self.enabled = True | ||||
| @@ -96,8 +101,14 @@ def create_observer_dict(mode): | |||||
| class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
| def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | |||||
| super().__init__(dtype) | |||||
| def __init__( | |||||
| self, | |||||
| mode=ObserverMode.SYMMERTIC, | |||||
| eps=0.00001, | |||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | |||||
| super().__init__(dtype, narrow_range) | |||||
| self.mode = mode | self.mode = mode | ||||
| self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | ||||
| self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | ||||
| @@ -107,6 +118,8 @@ class MinMaxObserver(Observer): | |||||
| min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
| max_val = F.maximum(0.0, inp_max_val) | max_val = F.maximum(0.0, inp_max_val) | ||||
| q_dict = create_observer_dict(self.mode) | q_dict = create_observer_dict(self.mode) | ||||
| q_dict["min_val"] = inp_min_val | |||||
| q_dict["max_val"] = inp_max_val | |||||
| if self.mode == ObserverMode.SYMMERTIC: | if self.mode == ObserverMode.SYMMERTIC: | ||||
| symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
| # use maximum to avoid the scale being too small at the beginning | # use maximum to avoid the scale being too small at the beginning | ||||
| @@ -151,9 +164,14 @@ class MinMaxObserver(Observer): | |||||
| class ExponentialMovingAverageObserver(MinMaxObserver): | class ExponentialMovingAverageObserver(MinMaxObserver): | ||||
| def __init__( | def __init__( | ||||
| self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8" | |||||
| self, | |||||
| momentum=0.9, | |||||
| mode=ObserverMode.SYMMERTIC, | |||||
| eps=0.00001, | |||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype) | |||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| self.momentum = Buffer(momentum) | self.momentum = Buffer(momentum) | ||||
| self.runtime_momentum = Buffer(0.0) | self.runtime_momentum = Buffer(0.0) | ||||
| @@ -186,11 +204,12 @@ class HistogramObserver(MinMaxObserver): | |||||
| self, | self, | ||||
| bins=2048, | bins=2048, | ||||
| upsample_rate=128, | upsample_rate=128, | ||||
| dtype="qint8", | |||||
| mode=ObserverMode.SYMMERTIC, | mode=ObserverMode.SYMMERTIC, | ||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype) | |||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| self.bins = bins | self.bins = bins | ||||
| self.upsample_rate = upsample_rate | self.upsample_rate = upsample_rate | ||||
| self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | ||||
| @@ -1,12 +1,14 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| # | # | ||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| # | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from functools import partial | |||||
| from ..module import Module | from ..module import Module | ||||
| from .fake_quant import FakeQuantize | |||||
| from .fake_quant import TQT, FakeQuantize | |||||
| from .observer import ( | from .observer import ( | ||||
| ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
| HistogramObserver, | HistogramObserver, | ||||
| @@ -22,9 +24,9 @@ class QConfig: | |||||
| :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | ||||
| how to collect scales and zero_point of weight. | how to collect scales and zero_point of weight. | ||||
| :param act_observer: similar to ``weight_observer`` but toward activation. | :param act_observer: similar to ``weight_observer`` but toward activation. | ||||
| :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
| how to do fake_quant calculation. can be invoked multi times to get different | |||||
| instance for each target tensor, for better control on enable and disable. | |||||
| :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
| how to do fake_quant calculation. | |||||
| :param act_fake_quant: similar to ``weight_fake_quant`` but toward activation. | |||||
| Examples: | Examples: | ||||
| @@ -32,14 +34,24 @@ class QConfig: | |||||
| # Default EMA QConfig for QAT. | # Default EMA QConfig for QAT. | ||||
| ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=ExponentialMovingAverageObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| Each parameter is a ``class`` rather than an instance. And we recommend using ``functools.partial`` | |||||
| to add initialization parameters of the ``class``, so that you don't need to provide parameters in | |||||
| :meth:`~.QATModule.set_qconfig`. | |||||
| Usually we set ``narrow_range`` of weight-related parameters to ``True`` and of activation-related | |||||
| parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||||
| four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||||
| Weights are commonly calculated in this way, so needed to narrow the range. | |||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, act_observer, weight_observer, fake_quant, | |||||
| self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||||
| ): | ): | ||||
| if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -47,24 +59,42 @@ class QConfig: | |||||
| " class generator using `partial(Observer, ...)` instead. Use" | " class generator using `partial(Observer, ...)` instead. Use" | ||||
| " partial(MyObserver, x=1) to override arguments to constructor if needed" | " partial(MyObserver, x=1) to override arguments to constructor if needed" | ||||
| ) | ) | ||||
| self.act_observer = act_observer | |||||
| self.weight_observer = weight_observer | self.weight_observer = weight_observer | ||||
| self.fake_quant = fake_quant | |||||
| self.act_observer = act_observer | |||||
| self.weight_fake_quant = weight_fake_quant | |||||
| self.act_fake_quant = act_fake_quant | |||||
| # Default QAT QConfigs | |||||
| tqt_quant_qconfig = QConfig( | |||||
| weight_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||||
| ), | |||||
| act_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
| ), | |||||
| weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
| ) | |||||
| min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=MinMaxObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=ExponentialMovingAverageObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
| ), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| calibration_qconfig = QConfig( | calibration_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=None, | |||||
| act_fake_quant=None, | |||||
| ) | ) | ||||
| @@ -6,6 +6,10 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from megengine._internal.plugin import load_tensor_binary | |||||
| def prod(iterable): | def prod(iterable): | ||||
| result = 1 | result = 1 | ||||
| for i in iterable: | for i in iterable: | ||||
| @@ -1,2 +1 @@ | |||||
| __version__ = "0.4.0" | |||||
| __version__ = "0.5.1" | |||||
| @@ -96,7 +96,6 @@ def test_deepcopy(): | |||||
| origin = Sigmoid(0) | origin = Sigmoid(0) | ||||
| new = copy.deepcopy(Sigmoid(0)) | new = copy.deepcopy(Sigmoid(0)) | ||||
| assert new.param == origin.param | assert new.param == origin.param | ||||
| assert new.saved_tensors == None | |||||
| def test_save_context(): | def test_save_context(): | ||||
| @@ -10,6 +10,7 @@ import numpy as np | |||||
| import pytest | import pytest | ||||
| import megengine as mge | import megengine as mge | ||||
| import megengine._internal as mgb | |||||
| def test_wrong_dtype(): | def test_wrong_dtype(): | ||||
| @@ -26,3 +27,48 @@ def test_tensor_routine(): | |||||
| mge.tensor([1]) | mge.tensor([1]) | ||||
| mge.tensor(1.5) | mge.tensor(1.5) | ||||
| def test_tensor_set_dtype(): | |||||
| def check_dtype_value(tensor, dtype_scale, value): | |||||
| if mgb.dtype.is_quantize(tensor.dtype): | |||||
| if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5: | |||||
| raise AssertionError( | |||||
| "compare scale failed expect {} got {}".format( | |||||
| dtype_scale, mgb.dtype.get_scale(tensor.dtype) | |||||
| ) | |||||
| ) | |||||
| if np.abs(tensor.numpy()[0][0] - value) > 1e-5: | |||||
| raise AssertionError( | |||||
| "compare value failed expect {} got {}".format( | |||||
| tensor.numpy()[0][0], value | |||||
| ) | |||||
| ) | |||||
| t = mge.Parameter(np.ones((3, 4), dtype="float32")) | |||||
| t.set_dtype(mgb.dtype.qint8(0.1)) | |||||
| check_dtype_value(t, 0.1, 10) | |||||
| t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||||
| check_dtype_value(t, 0.3, 3) | |||||
| t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||||
| t.set_dtype(mgb.dtype.qint8(0.1)) | |||||
| check_dtype_value(t, 0.1, 10) | |||||
| t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||||
| check_dtype_value(t, 0.3, 3) | |||||
| t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||||
| s = t + 1 | |||||
| s.set_dtype(mgb.dtype.qint8(0.2)) | |||||
| check_dtype_value(s, 0.2, 10) | |||||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||||
| s = t + 1 | |||||
| s.set_dtype(mgb.dtype.qint8(0.1)) | |||||
| check_dtype_value(s, 0.1, 18) | |||||
| s.set_dtype("float32") | |||||
| check_dtype_value(s, 0, 1.8) | |||||
| @@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception(): | |||||
| with pytest.raises(RuntimeError, match=r"worker.*died"): | with pytest.raises(RuntimeError, match=r"worker.*died"): | ||||
| data_iter = iter(dataloader) | data_iter = iter(dataloader) | ||||
| batch_data = next(data_iter) | batch_data = next(data_iter) | ||||
| def _multi_instances_parallel_dataloader_worker(): | |||||
| dataset = init_dataset() | |||||
| for divide_flag in [True, False]: | |||||
| train_dataloader = DataLoader( | |||||
| dataset, | |||||
| sampler=RandomSampler(dataset, batch_size=4, drop_last=False), | |||||
| num_workers=2, | |||||
| divide=divide_flag, | |||||
| ) | |||||
| val_dataloader = DataLoader( | |||||
| dataset, | |||||
| sampler=RandomSampler(dataset, batch_size=10, drop_last=False), | |||||
| num_workers=2, | |||||
| divide=divide_flag, | |||||
| ) | |||||
| for idx, (data, label) in enumerate(train_dataloader): | |||||
| assert data.shape == (4, 1, 32, 32) | |||||
| assert label.shape == (4,) | |||||
| if idx % 5 == 0: | |||||
| for val_data, val_label in val_dataloader: | |||||
| assert val_data.shape == (10, 1, 32, 32) | |||||
| assert val_label.shape == (10,) | |||||
| def test_dataloader_parallel_multi_instances(): | |||||
| # set max shared memory to 100M | |||||
| os.environ["MGE_PLASMA_MEMORY"] = "100000000" | |||||
| _multi_instances_parallel_dataloader_worker() | |||||
| def test_dataloader_parallel_multi_instances_multiprocessing(): | |||||
| # set max shared memory to 100M | |||||
| os.environ["MGE_PLASMA_MEMORY"] = "100000000" | |||||
| import multiprocessing as mp | |||||
| # mp.set_start_method("spawn") | |||||
| processes = [] | |||||
| for i in range(4): | |||||
| p = mp.Process(target=_multi_instances_parallel_dataloader_worker) | |||||
| p.start() | |||||
| processes.append(p) | |||||
| for p in processes: | |||||
| p.join() | |||||
| @@ -14,8 +14,10 @@ import pytest | |||||
| from helpers import MLP | from helpers import MLP | ||||
| import megengine as mge | import megengine as mge | ||||
| import megengine._internal as mgb | |||||
| from megengine.core import Buffer, Parameter, Tensor, tensor | from megengine.core import Buffer, Parameter, Tensor, tensor | ||||
| from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | ||||
| from megengine.quantization.quantize import quantize, quantize_qat | |||||
| from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
| @@ -347,3 +349,38 @@ def test_dump_model(): | |||||
| pred = mlp(data) | pred = mlp(data) | ||||
| with tempfile.NamedTemporaryFile() as f: | with tempfile.NamedTemporaryFile() as f: | ||||
| mge.dump(pred, f.name) | mge.dump(pred, f.name) | ||||
| def test_load_quantized(): | |||||
| data_shape = (2, 28) | |||||
| data = tensor(np.random.random(data_shape), dtype="float32") | |||||
| data = data.astype(mgb.dtype.qint8(0.1)) | |||||
| mlp = MLP() | |||||
| quantize_qat(mlp) | |||||
| quantize(mlp) | |||||
| mlp.dense0.weight = Parameter( | |||||
| mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy() | |||||
| ) | |||||
| mlp.dense1.weight = Parameter( | |||||
| mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy() | |||||
| ) | |||||
| mlp.eval() | |||||
| pred0 = mlp(data) | |||||
| with BytesIO() as fout: | |||||
| mge.save(mlp.state_dict(), fout) | |||||
| fout.seek(0) | |||||
| checkpoint = mge.load(fout) | |||||
| # change mlp weight. | |||||
| mlp.dense0.weight = Parameter( | |||||
| mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy() | |||||
| ) | |||||
| mlp.dense1.weight = Parameter( | |||||
| mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy() | |||||
| ) | |||||
| mlp.load_state_dict(checkpoint) | |||||
| pred1 = mlp(data) | |||||
| assertTensorClose( | |||||
| pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | |||||
| ) | |||||
| @@ -0,0 +1,77 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| import pytest | |||||
| import megengine as mge | |||||
| import megengine._internal as mgb | |||||
| from megengine.core import tensor | |||||
| from megengine.quantization.fake_quant import TQT_Function | |||||
| from megengine.test import assertTensorClose | |||||
| class numpy_TQT_Function: | |||||
| def __init__(self, lowerbound, upperbound): | |||||
| super().__init__() | |||||
| self.lowerbound = lowerbound | |||||
| self.upperbound = upperbound | |||||
| def forward(self, inp, scale): | |||||
| t = 2 ** scale | |||||
| # t = F.maximum(t, 1e-4) | |||||
| inp_scaled = inp / t | |||||
| inp_clipped = np.maximum( | |||||
| np.minimum(inp_scaled, self.upperbound), self.lowerbound | |||||
| ) | |||||
| inp_rounded = np.round(inp_clipped) | |||||
| inp_flq = inp_rounded * t | |||||
| self.saved_tensors = (inp_scaled, inp_rounded, t) | |||||
| return inp_flq | |||||
| def backward(self, grad_inp_flq): | |||||
| (inp_scaled, inp_rounded, t) = self.saved_tensors | |||||
| mask_clip = (inp_scaled < -0.5 + self.lowerbound) + ( | |||||
| inp_scaled > self.upperbound + 0.5 | |||||
| ) # mask for accumulating the gradients of |data_scaled|>L | |||||
| mask_quant = np.abs( | |||||
| mask_clip - 1 | |||||
| ) # mask for accumulating the gradients with |data_scaled|<=L | |||||
| grad_quant = ( | |||||
| grad_inp_flq * mask_quant * (inp_rounded - inp_scaled) | |||||
| ) # gradient within |data_scaled|<=L | |||||
| grad_clip = ( | |||||
| grad_inp_flq * mask_clip * inp_rounded | |||||
| ) # gradient with | data_scaled|>L | |||||
| grad_s = grad_clip.sum() + grad_quant.sum() | |||||
| # dL/ds = dL/dt * t * ln(2) | |||||
| grad_s = grad_s * t * np.log(2) | |||||
| grad_inp = grad_inp_flq * mask_quant | |||||
| return grad_inp, grad_s | |||||
| def test_TQT(): | |||||
| f = TQT_Function(-127, 127) | |||||
| nf = numpy_TQT_Function(-127, 127) | |||||
| def check_inp(a, b, c, a_np, b_np, c_np): | |||||
| assertTensorClose( | |||||
| f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32") | |||||
| ) | |||||
| c1, c2 = f.backward(c) | |||||
| c1_np, c2_np = nf.backward(c_np) | |||||
| assertTensorClose(c1.numpy(), c1_np.astype("float32")) | |||||
| assertTensorClose(c2.numpy(), c2_np.astype("float32")) | |||||
| a = tensor() | |||||
| b = tensor() | |||||
| a_np = np.random.random((4, 3)).astype("float32") | |||||
| b_np = np.random.random((1)).astype("float32") | |||||
| a.set_value(a_np) | |||||
| b.set_value(b_np) | |||||
| check_inp(a, b, b, a_np, b_np, b_np) | |||||
| @@ -14,7 +14,7 @@ import struct | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| import megbrain as mgb | |||||
| import megengine._internal as mgb | |||||
| import megengine as mge | import megengine as mge | ||||
| logger = mge.get_logger(__name__) | logger = mge.get_logger(__name__) | ||||
| @@ -709,6 +709,41 @@ void run_test_st(Args &env) { | |||||
| } | } | ||||
| }; | }; | ||||
| auto run_iters = [&](uint32_t case_idx) -> float { | |||||
| double time_sqrsum = 0, time_sum = 0, | |||||
| min_time = std::numeric_limits<double>::max(), max_time = 0; | |||||
| for (int run = 0; run < env.nr_run; ++run) { | |||||
| mgb_log_debug("load_and_run: before running iter %d", run); | |||||
| timer.reset(); | |||||
| func->execute(); | |||||
| mgb_log_debug("load_and_run: before waiting iter %d", run); | |||||
| auto exec_time = timer.get_msecs(); | |||||
| func->wait(); | |||||
| output_dumper.write_to_file(); | |||||
| auto cur = timer.get_msecs(); | |||||
| printf("iter %d/%d: %.3fms (exec=%.3f,device=%.3f)\n", run, | |||||
| env.nr_run, cur, exec_time, | |||||
| func->get_prev_exec_time() * 1e3); | |||||
| time_sum += cur; | |||||
| time_sqrsum += cur * cur; | |||||
| fflush(stdout); | |||||
| if (cur < min_time) { | |||||
| min_time = cur; | |||||
| } | |||||
| if (cur > max_time) { | |||||
| max_time = cur; | |||||
| } | |||||
| } | |||||
| printf("=== finished test #%u: time=%.3fms avg_time=%.3fms " | |||||
| "sd=%.3fms minmax=%.3f,%.3f\n\n", | |||||
| case_idx, time_sum, time_sum / env.nr_run, | |||||
| std::sqrt((time_sqrsum * env.nr_run - time_sum * time_sum) / | |||||
| (env.nr_run * (env.nr_run - 1))), | |||||
| min_time, max_time); | |||||
| return time_sum; | |||||
| }; | |||||
| if (nr_test) { | if (nr_test) { | ||||
| // run testcase, generated by dump_with_testcase.py | // run testcase, generated by dump_with_testcase.py | ||||
| @@ -742,37 +777,7 @@ void run_test_st(Args &env) { | |||||
| if (!env.nr_run) { | if (!env.nr_run) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| double time_sqrsum = 0, time_sum = 0, | |||||
| min_time = std::numeric_limits<double>::max(), max_time = 0; | |||||
| for (int run = 0; run < env.nr_run; ++ run) { | |||||
| mgb_log_debug("load_and_run: before running iter %d", run); | |||||
| timer.reset(); | |||||
| func->execute(); | |||||
| mgb_log_debug("load_and_run: before waiting iter %d", run); | |||||
| auto exec_time = timer.get_msecs(); | |||||
| func->wait(); | |||||
| output_dumper.write_to_file(); | |||||
| auto cur = timer.get_msecs(); | |||||
| printf("iter %d/%d: %.3fms (exec=%.3f,device=%.3f)\n", run, | |||||
| env.nr_run, cur, exec_time, | |||||
| func->get_prev_exec_time() * 1e3); | |||||
| time_sum += cur; | |||||
| time_sqrsum += cur * cur; | |||||
| fflush(stdout); | |||||
| if (cur < min_time) { | |||||
| min_time = cur; | |||||
| } | |||||
| if (cur > max_time) { | |||||
| max_time = cur; | |||||
| } | |||||
| } | |||||
| tot_time += time_sum; | |||||
| printf("=== finished test #%u: time=%.3fms avg_time=%.3fms " | |||||
| "sd=%.3fms minmax=%.3f,%.3f\n\n", | |||||
| i, time_sum, time_sum / env.nr_run, | |||||
| std::sqrt((time_sqrsum * env.nr_run - time_sum * time_sum) / | |||||
| (env.nr_run * (env.nr_run - 1))), | |||||
| min_time, max_time); | |||||
| tot_time += run_iters(i); | |||||
| } | } | ||||
| printf("=== total time: %.3fms\n", tot_time); | printf("=== total time: %.3fms\n", tot_time); | ||||
| @@ -793,15 +798,10 @@ void run_test_st(Args &env) { | |||||
| in->copy_from(i.second); | in->copy_from(i.second); | ||||
| } | } | ||||
| warmup(); | |||||
| timer.reset(); | timer.reset(); | ||||
| func->execute(); | |||||
| auto exec_time = timer.get_msecs(); | |||||
| func->wait(); | |||||
| output_dumper.write_to_file(); | |||||
| auto cur = timer.get_msecs(); | |||||
| printf("%.3fms %.3fms (device=%.3f)\n", cur, exec_time, | |||||
| func->get_prev_exec_time() * 1e3); | |||||
| printf("=== going to run input for %d times\n", env.nr_run); | |||||
| run_iters(0); | |||||
| } else { | } else { | ||||
| // run speed test for a raw mgb graph | // run speed test for a raw mgb graph | ||||
| mgb_assert(env.load_ret.tensor_map.empty(), | mgb_assert(env.load_ret.tensor_map.empty(), | ||||
| @@ -34,6 +34,11 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT) | |||||
| endif() | endif() | ||||
| if(MGE_WITH_CUDA) | |||||
| file(GLOB_RECURSE SOURCES_ opr/impl/standalone/*.cu) | |||||
| list(APPEND SOURCES ${SOURCES_}) | |||||
| endif() | |||||
| add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES}) | add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES}) | ||||
| target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) | target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) | ||||
| target_include_directories(megbrain | target_include_directories(megbrain | ||||
| @@ -795,7 +795,7 @@ bool CpuCompNode::CompNodeImpl::check_global_finalized(const char* reason) { | |||||
| /* ======================== CompNode methods ======================== */ | /* ======================== CompNode methods ======================== */ | ||||
| CompNode CompNode::default_cpu() { | CompNode CompNode::default_cpu() { | ||||
| static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, -1}; | |||||
| static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}}; | |||||
| static auto empty_queue = | static auto empty_queue = | ||||
| std::make_shared<CpuCompNode::WorkerQueue>(locator); | std::make_shared<CpuCompNode::WorkerQueue>(locator); | ||||
| static CpuCompNodeImpl impl{locator, locator, empty_queue}; | static CpuCompNodeImpl impl{locator, locator, empty_queue}; | ||||
| @@ -464,7 +464,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
| #if MGB_ENABLE_TENSOR_RT | #if MGB_ENABLE_TENSOR_RT | ||||
| if (options().graph_opt.tensorrt) { | if (options().graph_opt.tensorrt) { | ||||
| options().graph_opt.tensorrt = false; | options().graph_opt.tensorrt = false; | ||||
| tensorrt::transform_dest_vars_inplace(dest_vars); | |||||
| tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -12,8 +12,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #define MGB_MAJOR 8 | #define MGB_MAJOR 8 | ||||
| #define MGB_MINOR 4 | |||||
| #define MGB_PATCH 1 | |||||
| #define MGB_MINOR 5 | |||||
| #define MGB_PATCH 0 | |||||
| //! whether it is development version | //! whether it is development version | ||||
| #ifndef MGB_IS_DEV | #ifndef MGB_IS_DEV | ||||
| #define MGB_IS_DEV 0 | #define MGB_IS_DEV 0 | ||||
| @@ -756,6 +756,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
| cb(nchw32, { | cb(nchw32, { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
| add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
| add_pass(EnableTensorCorePass::make_tensorcore_converter()); | add_pass(EnableTensorCorePass::make_tensorcore_converter()); | ||||
| add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
| add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
| @@ -763,6 +764,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
| cb(chwn4, { | cb(chwn4, { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
| add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
| add_pass(EnableCHWN4Pass::make_chwn4_converter()); | add_pass(EnableCHWN4Pass::make_chwn4_converter()); | ||||
| add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
| add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
| @@ -60,19 +60,24 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, | |||||
| public: | public: | ||||
| //! relayout type of this opr | //! relayout type of this opr | ||||
| enum class LayoutType { | enum class LayoutType { | ||||
| NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout | |||||
| NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout | |||||
| NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout | |||||
| CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout | |||||
| NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout | |||||
| NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout | |||||
| NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout | |||||
| NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout | |||||
| NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout | |||||
| NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout | |||||
| NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout | |||||
| CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout | |||||
| NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout | |||||
| NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose | |||||
| ///< channel size less than 4 | |||||
| NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout | |||||
| NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout | |||||
| NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 | WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 | ||||
| //!< layout | //!< layout | ||||
| WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to | WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to | ||||
| //!< nchw4 layout | //!< nchw4 layout | ||||
| WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout | |||||
| //!< to nchw4 layout whose | |||||
| //! channel size less than 4 | |||||
| WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 | WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 | ||||
| //!< layout | //!< layout | ||||
| @@ -177,11 +182,21 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { | |||||
| dst[3] = inp_shape[2]; | dst[3] = inp_shape[2]; | ||||
| dst[4] = inp_shape[4]; | dst[4] = inp_shape[4]; | ||||
| } else if (layout_type() == | } else if (layout_type() == | ||||
| RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4){ | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
| RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 || | |||||
| layout_type() == RelayoutPlaceholder::LayoutType:: | |||||
| NCHW_TO_NCHW4_IC_SMALL_CONV) { | |||||
| if (layout_type() == | |||||
| RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) { | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
| } else { | |||||
| mgb_assert(layout_type() == | |||||
| RelayoutPlaceholder::LayoutType:: | |||||
| NCHW_TO_NCHW4_IC_SMALL_CONV); | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); | |||||
| } | |||||
| dst.ndim = 5; | dst.ndim = 5; | ||||
| dst[0] = inp_shape[0]; | dst[0] = inp_shape[0]; | ||||
| dst[1] = inp_shape[1] / 4; | |||||
| dst[1] = (inp_shape[1] + 4 - 1) / 4; | |||||
| dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
| dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
| dst[4] = 4; | dst[4] = 4; | ||||
| @@ -194,11 +209,23 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { | |||||
| dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
| dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
| } else if (layout_type() == RelayoutPlaceholder::LayoutType:: | } else if (layout_type() == RelayoutPlaceholder::LayoutType:: | ||||
| WEIGHT_NCHW_TO_NCHW4_DENSE) { | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE || | |||||
| layout_type() == | |||||
| RelayoutPlaceholder::LayoutType:: | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) { | |||||
| if (layout_type() == | |||||
| RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) { | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); | |||||
| } else { | |||||
| mgb_assert(layout_type() == | |||||
| RelayoutPlaceholder::LayoutType:: | |||||
| WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV); | |||||
| mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); | |||||
| } | |||||
| dst.ndim = 5; | dst.ndim = 5; | ||||
| dst[0] = inp_shape[0]; | dst[0] = inp_shape[0]; | ||||
| dst[1] = inp_shape[1] / 4; | |||||
| dst[1] = (inp_shape[1] + 4 - 1) / 4; | |||||
| dst[2] = inp_shape[2]; | dst[2] = inp_shape[2]; | ||||
| dst[3] = inp_shape[3]; | dst[3] = inp_shape[3]; | ||||
| dst[4] = 4; | dst[4] = 4; | ||||
| @@ -427,6 +454,23 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto y2 = opr::Reshape::make(y1, tshp1); | auto y2 = opr::Reshape::make(y1, tshp1); | ||||
| return y2.node(); | return y2.node(); | ||||
| }; | }; | ||||
| reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] = | |||||
| [](VarNode* inp) -> VarNode* { | |||||
| auto x = SymbolVar(inp); | |||||
| auto y = opr::RelayoutFormat::make( | |||||
| x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); | |||||
| return y.node(); | |||||
| }; | |||||
| reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] = | |||||
| [](VarNode* inp) -> VarNode* { | |||||
| auto x = SymbolVar(inp); | |||||
| auto y = opr::RelayoutFormat::make( | |||||
| x, megdnn::param::RelayoutFormat::Mode:: | |||||
| NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); | |||||
| return y.node(); | |||||
| }; | |||||
| reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* { | reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* { | ||||
| auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
| auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
| @@ -435,13 +479,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | ||||
| }; | }; | ||||
| auto tshp0 = opr::Concat::make( | auto tshp0 = opr::Concat::make( | ||||
| {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), | |||||
| tshp1 = opr::Concat::make( | |||||
| {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0); | |||||
| {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); | |||||
| auto y0 = opr::Reshape::make(x, tshp0); | auto y0 = opr::Reshape::make(x, tshp0); | ||||
| auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); | auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); | ||||
| auto y2 = opr::Reshape::make(y1, tshp1); | |||||
| return y2.node(); | |||||
| return y1.node(); | |||||
| }; | }; | ||||
| reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { | reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { | ||||
| auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
| @@ -455,7 +496,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto y1 = opr::Reshape::make(y0, tshp0); | auto y1 = opr::Reshape::make(y0, tshp0); | ||||
| return y1.node(); | return y1.node(); | ||||
| }; | }; | ||||
| reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* { | |||||
| reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = | |||||
| [](VarNode* inp) -> VarNode* { | |||||
| auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
| auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | auto cv = [&x](int v) { return x.make_scalar(v); }; | ||||
| @@ -471,7 +513,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { | |||||
| auto y2 = opr::Reshape::make(y1, tshp1); | auto y2 = opr::Reshape::make(y1, tshp1); | ||||
| return y2.node(); | return y2.node(); | ||||
| }; | }; | ||||
| reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* { | |||||
| reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = | |||||
| [](VarNode* inp) -> VarNode* { | |||||
| auto x = SymbolVar(inp); | auto x = SymbolVar(inp); | ||||
| auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | auto cv = [&x](int v) { return x.make_scalar(v); }; | ||||
| @@ -1357,56 +1400,71 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| using RelayoutMode = RelayoutPlaceholder::LayoutType; | using RelayoutMode = RelayoutPlaceholder::LayoutType; | ||||
| megdnn::param::Convolution::Format conv_format = | megdnn::param::Convolution::Format conv_format = | ||||
| megdnn::param::Convolution::Format::NCHW4; | megdnn::param::Convolution::Format::NCHW4; | ||||
| megdnn::param::ConvBias::Format conv_bias_format = | |||||
| megdnn::param::ConvBias::Format conv_bias_format = | |||||
| megdnn::param::ConvBias::Format::NCHW4; | megdnn::param::ConvBias::Format::NCHW4; | ||||
| megdnn::param::BatchConvBias::Format batch_conv_bias_format = | megdnn::param::BatchConvBias::Format batch_conv_bias_format = | ||||
| megdnn::param::BatchConvBias::Format::NCHW4; | megdnn::param::BatchConvBias::Format::NCHW4; | ||||
| RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; | RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; | ||||
| RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; | RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; | ||||
| RelayoutMode weight_to_nchw4_mode_dense = | |||||
| RelayoutMode weight_to_nchw4_mode_dense = | |||||
| RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; | RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; | ||||
| RelayoutMode weight_to_nchw4_mode_group = | |||||
| RelayoutMode weight_to_nchw4_mode_group = | |||||
| RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; | RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; | ||||
| auto trans_nchw4 = [weight_to_nchw4_mode_dense, | |||||
| weight_to_nchw4_mode_group]( | |||||
| struct ConvMode { | |||||
| RelayoutMode weight; | |||||
| RelayoutMode src; | |||||
| }; | |||||
| auto trans_nchw4 = | |||||
| [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group, | |||||
| src_to_nchw4_mode]( | |||||
| const megdnn::param::Convolution::Sparse conv_mode, | const megdnn::param::Convolution::Sparse conv_mode, | ||||
| const VarNode* filter) -> RelayoutMode { | |||||
| const VarNode* filter) -> ConvMode { | |||||
| if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | ||||
| mgb_assert(filter->shape().ndim == 4, | mgb_assert(filter->shape().ndim == 4, | ||||
| "The origin filter is not NCHW mode"); | "The origin filter is not NCHW mode"); | ||||
| size_t IC = filter->shape()[1]; | size_t IC = filter->shape()[1]; | ||||
| mgb_assert(IC % 4 == 0, | |||||
| "The input channel should be divisible by 4"); | |||||
| return weight_to_nchw4_mode_dense; | |||||
| if (IC < 4) { | |||||
| return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, | |||||
| RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV}; | |||||
| } else { | |||||
| return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; | |||||
| } | |||||
| } else { | } else { | ||||
| mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | ||||
| mgb_assert(filter->shape().ndim == 5, | mgb_assert(filter->shape().ndim == 5, | ||||
| "The origin filter if not NCHW mode"); | "The origin filter if not NCHW mode"); | ||||
| size_t IC = filter->shape()[2]; | size_t IC = filter->shape()[2]; | ||||
| mgb_assert(IC % 4 == 0, | mgb_assert(IC % 4 == 0, | ||||
| "The input channel should be divisible by 4"); | |||||
| return weight_to_nchw4_mode_group; | |||||
| "The input channel should be divisible by 4 for group " | |||||
| "conv"); | |||||
| return {weight_to_nchw4_mode_group, src_to_nchw4_mode}; | |||||
| } | } | ||||
| }; | }; | ||||
| auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode]( | |||||
| OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||||
| auto replace_conv_opr = [trans_nchw4, conv_format]( | |||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | ||||
| mgb_assert(conv_opr.param().format == | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHW4"); | |||||
| if (conv_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW) { | |||||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||||
| opr->config()); | |||||
| } | |||||
| auto conv_mode = | |||||
| trans_nchw4(conv_opr.param().sparse, new_inp[1]); | |||||
| VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; | VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; | ||||
| // src: NCHW --> NCWH4 | // src: NCHW --> NCWH4 | ||||
| if (new_inp[0]->shape().ndim != 5) { | if (new_inp[0]->shape().ndim != 5) { | ||||
| mgb_assert(new_inp[0]->shape().ndim == 4); | mgb_assert(new_inp[0]->shape().ndim == 4); | ||||
| auto new_src = RelayoutPlaceholder::make(new_inp[0], | |||||
| src_to_nchw4_mode); | |||||
| auto new_src = | |||||
| RelayoutPlaceholder::make(new_inp[0], conv_mode.src); | |||||
| conv_src = new_src.node(); | conv_src = new_src.node(); | ||||
| } | } | ||||
| // weight: NCHW --> NCHW4 | // weight: NCHW --> NCHW4 | ||||
| auto weight_mode = | |||||
| trans_nchw4(conv_opr.param().sparse, new_inp[1]); | |||||
| auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode); | |||||
| auto new_filter = | |||||
| RelayoutPlaceholder::make(new_inp[1], conv_mode.weight); | |||||
| conv_filter = new_filter.node(); | conv_filter = new_filter.node(); | ||||
| // format: NCHW --> NCHW4 | // format: NCHW --> NCHW4 | ||||
| auto new_param = conv_opr.param(); | auto new_param = conv_opr.param(); | ||||
| @@ -1428,7 +1486,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& batch_conv_bias_opr = | auto& batch_conv_bias_opr = | ||||
| opr->cast_final_safe<opr::BatchConvBiasForward>(); | opr->cast_final_safe<opr::BatchConvBiasForward>(); | ||||
| mgb_assert(batch_conv_bias_opr.param().format == | |||||
| if (batch_conv_bias_opr.param().format != | |||||
| megdnn::param::BatchConvBias::Format::NCHW) { | |||||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||||
| opr->config()); | |||||
| } | |||||
| mgb_assert(batch_conv_bias_opr.param().format == | |||||
| megdnn::param::BatchConvBias::Format::NCHW, | megdnn::param::BatchConvBias::Format::NCHW, | ||||
| "ConvertFormat Pass only support converting NCHW to NCHW4"); | "ConvertFormat Pass only support converting NCHW to NCHW4"); | ||||
| // what should be converted: src, weight | // what should be converted: src, weight | ||||
| @@ -1491,26 +1555,30 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| }; | }; | ||||
| auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, | auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, | ||||
| src_to_nchw4_mode]( | src_to_nchw4_mode]( | ||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
| mgb_assert(conv_bias_opr.param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHW4"); | |||||
| if (conv_bias_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW) { | |||||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||||
| opr->config()); | |||||
| } | |||||
| // what should be converted: src, weight | // what should be converted: src, weight | ||||
| VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; | VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; | ||||
| auto conv_mode = | |||||
| trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); | |||||
| // src: NCHW --> NCHW4 | // src: NCHW --> NCHW4 | ||||
| if (new_inp[0]->shape().ndim !=5) { | |||||
| if (new_inp[0]->shape().ndim != 5) { | |||||
| mgb_assert(new_inp[0]->shape().ndim == 4); | mgb_assert(new_inp[0]->shape().ndim == 4); | ||||
| auto new_src = RelayoutPlaceholder::make(new_inp[0], | |||||
| src_to_nchw4_mode); | |||||
| auto new_src = | |||||
| RelayoutPlaceholder::make(new_inp[0], conv_mode.src); | |||||
| conv_bias_src = new_src.node(); | conv_bias_src = new_src.node(); | ||||
| } | } | ||||
| // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4 | // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4 | ||||
| auto weight_mode = | |||||
| trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); | |||||
| auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode); | |||||
| auto new_filter = | |||||
| RelayoutPlaceholder::make(new_inp[1], conv_mode.weight); | |||||
| conv_bias_filter = new_filter.node(); | conv_bias_filter = new_filter.node(); | ||||
| // format: NCHW --> NCHW4 | // format: NCHW --> NCHW4 | ||||
| auto new_param = conv_bias_opr.param(); | auto new_param = conv_bias_opr.param(); | ||||
| @@ -1527,8 +1595,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| // bias: NCHW --> NCHW4 | // bias: NCHW --> NCHW4 | ||||
| VarNode* conv_bias_bias = new_inp[2]; | VarNode* conv_bias_bias = new_inp[2]; | ||||
| if (new_inp[2]->shape().ndim == 4) { | if (new_inp[2]->shape().ndim == 4) { | ||||
| auto new_bias = RelayoutPlaceholder::make(new_inp[2], | |||||
| src_to_nchw4_mode); | |||||
| auto new_bias = | |||||
| RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode); | |||||
| conv_bias_bias = new_bias.node(); | conv_bias_bias = new_bias.node(); | ||||
| } | } | ||||
| if (new_inp.size() == 3) { | if (new_inp.size() == 3) { | ||||
| @@ -1543,8 +1611,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| // z_inp: NCHW --> NCHW4 | // z_inp: NCHW --> NCHW4 | ||||
| VarNode* z_inp = new_inp[3]; | VarNode* z_inp = new_inp[3]; | ||||
| if (new_inp[3]->shape().ndim == 4) { | if (new_inp[3]->shape().ndim == 4) { | ||||
| auto new_z = RelayoutPlaceholder::make(new_inp[3], | |||||
| src_to_nchw4_mode); | |||||
| auto new_z = | |||||
| RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode); | |||||
| z_inp = new_z.node(); | z_inp = new_z.node(); | ||||
| } | } | ||||
| auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src, | auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src, | ||||
| @@ -1599,18 +1667,100 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| } | } | ||||
| return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | ||||
| }; | }; | ||||
| auto replace_pooling_opr = [](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| using Param = opr::PoolingForward::Param; | |||||
| using Format = Param::Format; | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); | |||||
| if (pooling.param().format != Format::NCHW) { | |||||
| return opr; | |||||
| } | |||||
| if (new_inp[0]->shape().ndim == 5) { | |||||
| mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
| auto new_param = pooling.param(); | |||||
| new_param.format = Format::NCHW4; | |||||
| auto new_pooling = | |||||
| opr::PoolingForward::make(new_inp[0], new_param, opr->config()); | |||||
| mgb_assert(new_pooling.shape().ndim == 5, | |||||
| "out var of Pooling opr after transform must be 5 (got: " | |||||
| "%zu).", | |||||
| new_pooling.shape().ndim); | |||||
| return new_pooling.node()->owner_opr(); | |||||
| } | |||||
| auto new_opr = | |||||
| serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
| return new_opr; | |||||
| }; | |||||
| auto replace_resize_opr = [](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| using Param = opr::ResizeForward::Param; | |||||
| using Format = Param::Format; | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| auto& resize = opr->cast_final_safe<opr::ResizeForward>(); | |||||
| if (new_inp[0]->shape().ndim == 5) { | |||||
| mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
| auto new_param = resize.param(); | |||||
| new_param.format = Format::NCHW4; | |||||
| auto new_resize = opr::ResizeForward::make( | |||||
| new_inp[0], new_inp[1], new_param, opr->config()); | |||||
| mgb_assert(new_resize.shape().ndim == 5, | |||||
| "out var of Resize opr after transform must be 5 (got: " | |||||
| "%zu).", | |||||
| new_resize.shape().ndim); | |||||
| return new_resize.node()->owner_opr(); | |||||
| } | |||||
| auto new_opr = | |||||
| serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
| return new_opr; | |||||
| }; | |||||
| auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, | |||||
| const VarNodeArray& new_inp) { | |||||
| using Param = opr::WarpPerspective::Param; | |||||
| using Format = Param::Format; | |||||
| mgb_assert(opr->input().size() == new_inp.size()); | |||||
| auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); | |||||
| if (new_inp[0]->shape().ndim == 5) { | |||||
| mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
| auto new_param = warp.param(); | |||||
| new_param.format = Format::NCHW4; | |||||
| SymbolVar new_warp; | |||||
| if (new_inp.size() == 3) { | |||||
| new_warp = opr::WarpPerspectiveForward::make( | |||||
| new_inp[0], new_inp[1], nullptr, new_inp[2], new_param, | |||||
| opr->config()); | |||||
| } else { | |||||
| mgb_assert(new_inp.size() == 4); | |||||
| new_warp = opr::WarpPerspectiveForward::make( | |||||
| new_inp[0], new_inp[1], new_inp[2], new_inp[3], | |||||
| new_param, opr->config()); | |||||
| } | |||||
| mgb_assert(new_warp.shape().ndim == 5, | |||||
| "out var of WarpPerspective opr after transform must be " | |||||
| "5 (got: " | |||||
| "%zu).", | |||||
| new_warp.shape().ndim); | |||||
| return new_warp.node()->owner_opr(); | |||||
| } | |||||
| auto new_opr = | |||||
| serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
| return new_opr; | |||||
| }; | |||||
| auto&& replace_func = ret->m_opr_replace_func; | auto&& replace_func = ret->m_opr_replace_func; | ||||
| //! supportted nchw4 | //! supportted nchw4 | ||||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
| replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
| replace_func[opr::BatchConvBias::typeinfo()] = | replace_func[opr::BatchConvBias::typeinfo()] = | ||||
| replace_batch_conv_bias_opr; | replace_batch_conv_bias_opr; | ||||
| replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||||
| replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||||
| replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
| replace_warp_perspective_opr; | |||||
| replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | ||||
| replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | ||||
| replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | ||||
| replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | ||||
| //! not supported nchw4 | //! not supported nchw4 | ||||
| replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw; | |||||
| replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = | replace_func[opr::ConvolutionBackwardData::typeinfo()] = | ||||
| relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
| @@ -1620,9 +1770,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
| replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | ||||
| replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; | |||||
| replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
| relayout_inp_to_nchw; | |||||
| replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -1512,6 +1512,7 @@ TEST_PASS(FuseConvBiasNonlinPass, Basic) { | |||||
| #if MGB_CUDA | #if MGB_CUDA | ||||
| TEST(TestEnableTensorCore, SmallInputShape) { | TEST(TestEnableTensorCore, SmallInputShape) { | ||||
| REQUIRE_GPU(1); | REQUIRE_GPU(1); | ||||
| auto cn = CompNode::load("gpu0"); | auto cn = CompNode::load("gpu0"); | ||||
| @@ -1579,6 +1580,104 @@ TEST(TestEnableTensorCore, SmallInputShape) { | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | ||||
| } | } | ||||
| TEST(TestEnableTensorCore, Nchw4Nchw) { | |||||
| REQUIRE_GPU(1); | |||||
| auto cn = CompNode::load("gpu0"); | |||||
| cn.activate(); | |||||
| auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; | |||||
| auto sm_ver = prop.major * 10 + prop.minor; | |||||
| if (sm_ver < 75) { | |||||
| printf("This testcast ignored due to insufficient cuda cap(got: %d, " | |||||
| "expected: %d)\n", | |||||
| sm_ver, 75); | |||||
| return; | |||||
| } | |||||
| HostTensorGenerator<dtype::Int8> gen; | |||||
| auto graph = ComputingGraph::make(); | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
| .rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C, | |||||
| size_t H, size_t W) -> TensorShape { | |||||
| mgb_assert(C % 4 == 0); | |||||
| if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
| return {N, C / 4, H, W, 4}; | |||||
| } else { | |||||
| mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
| return {N, C, H, W}; | |||||
| } | |||||
| }; | |||||
| for (auto format : {opr::ConvBias::Param::Format::NCHW, | |||||
| opr::ConvBias::Param::Format::NCHW4}) { | |||||
| auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), | |||||
| dtype::QuantizedS8(2.5f)), | |||||
| w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), | |||||
| dtype::QuantizedS8(2.5f)), | |||||
| b = mkcvar("b", mkshape(format, 1, 64, 1, 1), | |||||
| dtype::QuantizedS32(6.25f)), | |||||
| z = mkcvar("b1", mkshape(format, 32, 64, 8, 8), | |||||
| dtype::QuantizedS8(2.5f)); | |||||
| opr::ConvBias::Param param; | |||||
| param.format = format; | |||||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
| param.stride_h = param.stride_w = 2; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto y = opr::ConvBias::make( | |||||
| x, w, b, z, param, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); | |||||
| y = opr::ConvBias::make(y, w, b, param, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); | |||||
| y = opr::TypeCvt::make(y, dtype::Float32()); | |||||
| SymbolVar y_opt; | |||||
| SymbolVar y_no_tc; | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_nchw32().enable_fuse_conv_bias_nonlinearity(); | |||||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
| } | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_fuse_conv_bias_nonlinearity(); | |||||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc); | |||||
| } | |||||
| auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); | |||||
| std::string json_name; | |||||
| ASSERT_EQ(2u, nr_dimshuffle); | |||||
| if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
| json_name = "TestGoptInference.Nchw4Nchw.NCHW4.json"; | |||||
| } else { | |||||
| mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
| json_name = "TestGoptInference.Nchw4Nchw.NCHW.json"; | |||||
| } | |||||
| graph->compile({{y_opt, {}}}) | |||||
| ->to_json() | |||||
| ->writeto_fpath(output_file(json_name.c_str())); | |||||
| HostTensorND host_y, host_y_opt; | |||||
| auto func = graph->compile({make_callback_copy(y_no_tc, host_y), | |||||
| make_callback_copy(y_opt, host_y_opt)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
| } | |||||
| } | |||||
| TEST(TestEnableTensorCore, ConvBiasWithZ) { | TEST(TestEnableTensorCore, ConvBiasWithZ) { | ||||
| REQUIRE_GPU(1); | REQUIRE_GPU(1); | ||||
| auto cn = CompNode::load("gpu0"); | auto cn = CompNode::load("gpu0"); | ||||
| @@ -2043,53 +2142,74 @@ TEST(TestGoptInference, EnableCHWN4) { | |||||
| .rename(name), | .rename(name), | ||||
| dtype); | dtype); | ||||
| }; | }; | ||||
| auto mkshape = [](opr::ConvBias::Param::Format format, size_t N, size_t C, | |||||
| size_t H, size_t W) -> TensorShape { | |||||
| mgb_assert(C % 4 == 0); | |||||
| if (format == opr::ConvBias::Param::Format::NCHW4) { | |||||
| return {N, C / 4, H, W, 4}; | |||||
| } else { | |||||
| mgb_assert(format == opr::ConvBias::Param::Format::NCHW); | |||||
| return {N, C, H, W}; | |||||
| } | |||||
| }; | |||||
| auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)), | |||||
| w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), | |||||
| b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), | |||||
| b1 = mkvar("b1", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)); | |||||
| opr::ConvBias::Param param; | |||||
| param.format = opr::ConvBias::Param::Format::NCHW4; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
| for (auto format : {opr::ConvBias::Param::Format::NCHW, | |||||
| opr::ConvBias::Param::Format::NCHW4}) { | |||||
| auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), | |||||
| dtype::QuantizedS8(2.5f)), | |||||
| w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), | |||||
| dtype::QuantizedS8(2.5f)), | |||||
| b = mkcvar("b", mkshape(format, 1, 64, 1, 1), | |||||
| dtype::QuantizedS32(6.25f)), | |||||
| b1 = mkvar("b1", mkshape(format, 32, 64, 16, 16), | |||||
| dtype::QuantizedS8(2.5f)); | |||||
| opr::ConvBias::Param param; | |||||
| param.format = format; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
| auto y = opr::ConvBiasForward::make( | |||||
| x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y1 = opr::ElemwiseMultiType::make( | |||||
| {y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y2 = opr::ConvBiasForward::make( | |||||
| y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y3 = opr::ElemwiseMultiType::make( | |||||
| {y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y4 = opr::ElemwiseMultiType::make( | |||||
| {y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| y4 = opr::ElemwiseMultiType::make( | |||||
| {y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| y4 = opr::TypeCvt::make(y4, dtype::Float32()); | |||||
| SymbolVar y_opt; | |||||
| SymbolVar y_cudnn; | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_chwn4(); | |||||
| unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); | |||||
| auto y = opr::ConvBiasForward::make( | |||||
| x, w, b, param, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y1 = opr::ElemwiseMultiType::make( | |||||
| {y, b1}, opr::ElemwiseMultiType::Mode::QFUSE_ADD_RELU, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y2 = opr::ConvBiasForward::make( | |||||
| y, w, b, param, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y3 = opr::ElemwiseMultiType::make( | |||||
| {y, b1}, opr::ElemwiseMultiType::Param::Mode::QSUB, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y4 = opr::ElemwiseMultiType::make( | |||||
| {y1, y2}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| y4 = opr::ElemwiseMultiType::make( | |||||
| {y3, y4}, opr::ElemwiseMultiType::Param::Mode::QADD, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| y4 = opr::TypeCvt::make(y4, dtype::Float32()); | |||||
| SymbolVar y_opt; | |||||
| SymbolVar y_cudnn; | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_chwn4(); | |||||
| unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); | |||||
| } | |||||
| unpack_vector(gopt::GraphOptimizer{} | |||||
| .add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
| .add_pass<gopt::FuseConvBiasZPass>() | |||||
| .apply({{y4}}) | |||||
| .endpoint_vars(), | |||||
| y_cudnn); | |||||
| ASSERT_EQ(opr::ConvBias::Param::Format::CHWN4, | |||||
| find_opr<opr::ConvBias>(y_opt).param().format); | |||||
| HostTensorND host_y, host_y_opt; | |||||
| auto func = graph->compile({make_callback_copy(y_cudnn, host_y), | |||||
| make_callback_copy(y_opt, host_y_opt)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
| } | } | ||||
| unpack_vector(gopt::GraphOptimizer{} | |||||
| .add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
| .add_pass<gopt::FuseConvBiasZPass>() | |||||
| .apply({{y4}}) | |||||
| .endpoint_vars(), | |||||
| y_cudnn); | |||||
| HostTensorND host_y, host_y_opt; | |||||
| auto func = graph->compile({make_callback_copy(y_cudnn, host_y), | |||||
| make_callback_copy(y_opt, host_y_opt)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||||
| } | } | ||||
| TEST(TestGoptInference, EnableCHWN4WarpPespective) { | TEST(TestGoptInference, EnableCHWN4WarpPespective) { | ||||
| @@ -2430,14 +2550,16 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
| auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
| b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
| auto conv1 = opr::ConvBiasForward::make( | auto conv1 = opr::ConvBiasForward::make( | ||||
| x, w1, b1, param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| x, w1, b1, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| // group | // group | ||||
| // icpg != 1 && ocpg != 1 | // icpg != 1 && ocpg != 1 | ||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | ||||
| auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
| b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
| auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2, | |||||
| param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto conv2 = opr::ConvBiasForward::make( | |||||
| conv1, w2, b2, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | ||||
| @@ -2450,11 +2572,13 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | ||||
| find_opr<opr::ConvBias>(y_opt).param().format); | find_opr<opr::ConvBias>(y_opt).param().format); | ||||
| auto nr_reshape = find_opr_num<mgb::opr::Reshape>(y_opt); | |||||
| ASSERT_EQ(2u, nr_reshape); | |||||
| graph->compile({{y_opt, {}}}) | graph->compile({{y_opt, {}}}) | ||||
| ->to_json() | ->to_json() | ||||
| ->writeto_fpath( | |||||
| output_file("TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
| ->writeto_fpath(output_file( | |||||
| "TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
| HostTensorND host_y, host_y_opt; | HostTensorND host_y, host_y_opt; | ||||
| auto func = graph->compile({make_callback_copy(y, host_y), | auto func = graph->compile({make_callback_copy(y, host_y), | ||||
| @@ -2465,6 +2589,90 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
| #endif | #endif | ||||
| TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) { | |||||
| auto cn = CompNode::load("xpu0"); | |||||
| HostTensorGenerator<dtype::Int8> gen; | |||||
| auto graph = ComputingGraph::make(); | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
| .rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkcvarf32 = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
| .rename(name); | |||||
| }; | |||||
| auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f)); | |||||
| opr::ConvBias::Param param_conv_bias; | |||||
| param_conv_bias.format = opr::ConvBias::Param::Format::NCHW; | |||||
| param_conv_bias.stride_h = param_conv_bias.stride_w = 1; | |||||
| param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||||
| param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
| // dense | |||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||||
| auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
| b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
| auto conv1 = opr::ConvBiasForward::make( | |||||
| x, w1, b1, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| // test Resize | |||||
| auto shape_of = opr::GetVarShape::make(x); | |||||
| auto subtensor = opr::Subtensor::make( | |||||
| shape_of, {opr::Subtensor::AxisIndexer::make_interval( | |||||
| 0, x.make_scalar(2), None, x.make_scalar(1))}); | |||||
| opr::Resize::Param param_resize; | |||||
| param_resize.format = opr::Resize::Param::Format::NCHW; | |||||
| auto resize = opr::ResizeForward::make(conv1, subtensor * 2, param_resize); | |||||
| // test WarpPerspective | |||||
| auto mat = mkcvarf32("mat", {2, 3, 3}), | |||||
| warp = opr::WarpPerspectiveForward::make( | |||||
| resize, mat, nullptr, cg::var_from_tensor_shape(x, {32, 32})); | |||||
| opr::Pooling::Param pool_param; | |||||
| pool_param.format = opr::Pooling::Param::Format::NCHW; | |||||
| // test Pooling | |||||
| auto pool = opr::Pooling::make(warp, pool_param); | |||||
| // group | |||||
| // icpg != 1 && ocpg != 1 | |||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
| auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
| b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
| auto conv2 = opr::ConvBiasForward::make( | |||||
| pool, w2, b2, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto add = opr::ElemwiseMultiType::make( | |||||
| {conv1, conv2}, {opr::ElemwiseMultiType::Param::Mode::QADD}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); | |||||
| auto y = opr::TypeCvt::make(add, dtype::Float32()); | |||||
| SymbolVar y_opt; | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_nchw4(); | |||||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
| } | |||||
| auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); | |||||
| ASSERT_EQ(2u, nr_dimshuffle); | |||||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||||
| find_opr<opr::ConvBias>(y_opt).param().format); | |||||
| ASSERT_EQ(opr::ResizeForward::Param::Format::NCHW4, | |||||
| find_opr<opr::ResizeForward>(y_opt).param().format); | |||||
| ASSERT_EQ(opr::WarpPerspectiveForward::Param::Format::NCHW4, | |||||
| find_opr<opr::WarpPerspectiveForward>(y_opt).param().format); | |||||
| ASSERT_EQ(opr::PoolingForward::Param::Format::NCHW4, | |||||
| find_opr<opr::PoolingForward>(y_opt).param().format); | |||||
| } | |||||
| TEST(TestGoptInference, ConvertFormatNCHW4) { | TEST(TestGoptInference, ConvertFormatNCHW4) { | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||
| @@ -2479,7 +2687,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { | |||||
| }; | }; | ||||
| auto x = mkvar("x", {2, 4, 16, 16}); | auto x = mkvar("x", {2, 4, 16, 16}); | ||||
| // ConvBias | |||||
| // ConvBias test dense | |||||
| opr::ConvBias::Param param_conv_bias; | opr::ConvBias::Param param_conv_bias; | ||||
| param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | ||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | ||||
| @@ -2517,6 +2725,67 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { | |||||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | ||||
| } | } | ||||
| TEST(TestGoptInference, ConvertFormatNCHW4Ic3) { | |||||
| REQUIRE_GPU(1); | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{ | |||||
| 1.2f, 127 * 127}; | |||||
| auto graph = ComputingGraph::make(); | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::SharedDeviceTensor::make(*graph, *gen(shp)) | |||||
| .rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto x = mkvar("x", {2, 3, 16, 16}, dtype::QuantizedS8(2.5f)); | |||||
| // ConvBias test dense | |||||
| opr::ConvBias::Param param_conv_bias; | |||||
| param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||||
| auto w1 = mkcvar("w1", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
| b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
| auto conv1 = | |||||
| opr::ConvBias::make(x, w1, b1, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
| auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
| b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
| auto conv2 = | |||||
| opr::ConvBias::make(conv1, w2, b2, param_conv_bias, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | |||||
| SymbolVar y_opt; | |||||
| { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_nchw4(); | |||||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
| } | |||||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||||
| find_opr<opr::ConvBias>(y_opt).param().format); | |||||
| graph->compile({{y_opt, {}}}) | |||||
| ->to_json() | |||||
| ->writeto_fpath(output_file( | |||||
| "TestGoptInference.ConvertFormatNCHW4Ic3.json")); | |||||
| HostTensorND host_y_opt, host_y; | |||||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||||
| make_callback_copy(y_opt, host_y_opt)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||||
| } | |||||
| TEST(TestGoptInference, ConvertFormatNCHW88) { | TEST(TestGoptInference, ConvertFormatNCHW88) { | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||
| @@ -55,3 +55,8 @@ struct IndexDescMaskItem { | |||||
| table IndexDescMaskDump { | table IndexDescMaskDump { | ||||
| items:[IndexDescMaskItem]; | items:[IndexDescMaskItem]; | ||||
| } | } | ||||
| table NMSKeep { | |||||
| iou_thresh:float; | |||||
| max_output:uint; | |||||
| } | |||||
| @@ -30,74 +30,75 @@ table Blob { | |||||
| table Reserved0 {} | table Reserved0 {} | ||||
| union OperatorParam { | union OperatorParam { | ||||
| param.Empty, | |||||
| param.Axis, | |||||
| param.Convolution, | |||||
| param.MaskPropagate, | |||||
| param.ConvPooling, | |||||
| param.ConvBias, | |||||
| param.SeparableConv, | |||||
| param.Images2Neibs, | |||||
| param.Pooling, | |||||
| param.LRN, | |||||
| param.BN, | |||||
| param.ROIPooling, | |||||
| param.WarpPerspective, | |||||
| param.SpatialTfGridGenerator, | |||||
| param.SpatialTfSampler, | |||||
| param.MGBAddUpdate, | |||||
| param.Elemwise, | |||||
| param.ElemwiseMultiType, | |||||
| param.PowC, | |||||
| param.MatrixMul, | |||||
| param.Winograd, | |||||
| param.SVD, | |||||
| param.Reduce, | |||||
| param.Cumsum, | |||||
| param.CondTake, | |||||
| param.Argsort, | |||||
| param.IndexingRemap, | |||||
| param.MGBSleep, | |||||
| param.Linspace, | |||||
| param.LinspaceFull, | |||||
| param.Eye, | |||||
| param.UniformRNG, | |||||
| param.GaussianRNG, | |||||
| param.Flip, | |||||
| param.Rotate, | |||||
| param.ROICopy, | |||||
| param.CvtColor, | |||||
| param.WarpAffine, | |||||
| param.GaussianBlur, | |||||
| param.Resize, | |||||
| param.Remap, | |||||
| param.Convolution3D, | |||||
| param.Conv3DBias, | |||||
| param.SeparableConv3D, | |||||
| param.TopK, | |||||
| param.RelayoutFormat, | |||||
| param.SeparableFilter, | |||||
| param.LocalShare, | |||||
| param.ROIAlign, | |||||
| param.DeformablePSROIPooling, | |||||
| param.BatchConvBias, | |||||
| param.DType, | |||||
| param.PersistentOutputStorage, | |||||
| param.OptionalAxis, | |||||
| param.OptionalAxisV1, | |||||
| param.ExecutionPolicy, | |||||
| param.AssertEqual, | |||||
| Reserved0, | |||||
| param.CollectiveComm, | |||||
| param.CondExecPred, | |||||
| param.CondExecPredLogical, | |||||
| param.CondExecMark, | |||||
| param.CondExecMerge, | |||||
| param.Host2DeviceCopy, | |||||
| param.Dimshuffle, | |||||
| param.AxisAddRemove, | |||||
| param.IndexDescMaskDump, | |||||
| DType, | |||||
| param.Empty = 1, | |||||
| param.Axis = 2, | |||||
| param.Convolution = 3, | |||||
| param.MaskPropagate = 4, | |||||
| param.ConvPooling = 5, | |||||
| param.ConvBias = 6, | |||||
| param.SeparableConv = 7, | |||||
| param.Images2Neibs = 8, | |||||
| param.Pooling = 9, | |||||
| param.LRN = 10, | |||||
| param.BN = 11, | |||||
| param.ROIPooling = 12, | |||||
| param.WarpPerspective = 13, | |||||
| param.SpatialTfGridGenerator = 14, | |||||
| param.SpatialTfSampler = 15, | |||||
| param.MGBAddUpdate = 16, | |||||
| param.Elemwise = 17, | |||||
| param.ElemwiseMultiType = 18, | |||||
| param.PowC = 19, | |||||
| param.MatrixMul = 20, | |||||
| param.Winograd = 21, | |||||
| param.SVD = 22, | |||||
| param.Reduce = 23, | |||||
| param.Cumsum = 24, | |||||
| param.CondTake = 25, | |||||
| param.Argsort = 26, | |||||
| param.IndexingRemap = 27, | |||||
| param.MGBSleep = 28, | |||||
| param.Linspace = 29, | |||||
| param.LinspaceFull = 30, | |||||
| param.Eye = 31, | |||||
| param.UniformRNG = 32, | |||||
| param.GaussianRNG = 33, | |||||
| param.Flip = 34, | |||||
| param.Rotate = 35, | |||||
| param.ROICopy = 36, | |||||
| param.CvtColor = 37, | |||||
| param.WarpAffine = 38, | |||||
| param.GaussianBlur = 39, | |||||
| param.Resize = 40, | |||||
| param.Convolution3D = 41, | |||||
| param.Conv3DBias = 42, | |||||
| param.SeparableConv3D = 43, | |||||
| param.TopK = 44, | |||||
| param.RelayoutFormat = 45, | |||||
| param.SeparableFilter = 46, | |||||
| param.LocalShare = 47, | |||||
| param.ROIAlign = 48, | |||||
| param.DeformablePSROIPooling = 49, | |||||
| param.BatchConvBias = 50, | |||||
| param.DType = 51, | |||||
| param.PersistentOutputStorage = 52, | |||||
| param.OptionalAxis = 53, | |||||
| param.OptionalAxisV1 = 54, | |||||
| param.ExecutionPolicy = 55, | |||||
| param.AssertEqual = 56, | |||||
| Reserved0 = 57, | |||||
| param.CollectiveComm = 58, | |||||
| param.CondExecPred = 59, | |||||
| param.CondExecPredLogical = 60, | |||||
| param.CondExecMark = 61, | |||||
| param.CondExecMerge = 62, | |||||
| param.Host2DeviceCopy = 63, | |||||
| param.Dimshuffle = 64, | |||||
| param.AxisAddRemove = 65, | |||||
| param.IndexDescMaskDump = 66, | |||||
| DType = 67, | |||||
| param.Remap = 68, | |||||
| param.NMSKeep = 69, | |||||
| } | } | ||||
| table Operator { | table Operator { | ||||
| @@ -846,7 +846,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, | |||||
| OprLoadContextImpl ctx{this, m_graph->mgb_version()}; | OprLoadContextImpl ctx{this, m_graph->mgb_version()}; | ||||
| auto result = ctx.load_oprs(); | auto result = ctx.load_oprs(); | ||||
| auto fbs_end = tensor_begin + offset_to_fbs + size; | |||||
| auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size; | |||||
| auto cur = m_file->tell(); | auto cur = m_file->tell(); | ||||
| mgb_assert(fbs_end > cur); | mgb_assert(fbs_end > cur); | ||||
| // Skip to Graph end | // Skip to Graph end | ||||
| @@ -872,4 +872,4 @@ bool is_fbs_file(InputFile& file) { | |||||
| } // namespace serialization | } // namespace serialization | ||||
| } // namespace mgb | } // namespace mgb | ||||
| #endif | |||||
| #endif | |||||
| @@ -64,6 +64,34 @@ TEST(TestSerializer2, GraphDumpLoad) { | |||||
| load(); | load(); | ||||
| } | } | ||||
| TEST(TestSerializer2, MultiGraphDumpLoad) { | |||||
| auto fname = GET_OUTPUT_FILE(); | |||||
| auto dump = [&]() { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto x = opr::ImmutableTensor::make(*graph, 1926.0817f, {cn}); | |||||
| x.rename("varz"); | |||||
| auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), | |||||
| GraphDumpFormat::FLATBUFFERS); | |||||
| // dump twice | |||||
| dumper->dump({x}); | |||||
| dumper->dump({x}); | |||||
| }; | |||||
| auto load = [&]() { | |||||
| GraphLoader::LoadConfig load_config = {}; | |||||
| auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), | |||||
| GraphDumpFormat::FLATBUFFERS); | |||||
| // load twice | |||||
| loader->load(load_config, false); | |||||
| loader = GraphLoader::make(loader->reset_file(), loader->format()); | |||||
| loader->load(load_config, false); | |||||
| }; | |||||
| dump(); | |||||
| load(); | |||||
| } | |||||
| TEST(TestSerializer2, APlusB) { | TEST(TestSerializer2, APlusB) { | ||||
| auto fname = GET_OUTPUT_FILE(); | auto fname = GET_OUTPUT_FILE(); | ||||
| TensorShape shape{2, 3}; | TensorShape shape{2, 3}; | ||||
| @@ -733,4 +761,4 @@ TEST(TestSerializer2, HasOutputDtype) { | |||||
| load(); | load(); | ||||
| } | } | ||||
| #endif | |||||
| #endif | |||||
| @@ -1727,8 +1727,17 @@ void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() { | |||||
| } | } | ||||
| } | } | ||||
| void mgb::tensorrt::transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars) { | |||||
| void mgb::tensorrt::transform_dest_vars_inplace( | |||||
| mgb::cg::VarNodeArray& dest_vars, | |||||
| cg::GraphCommonOptimizeOptions& options) { | |||||
| gopt::GraphOptimizer optimizer; | gopt::GraphOptimizer optimizer; | ||||
| //! As in megengine, the layout is NCHW, while tensorrt pass currently | |||||
| //! only support NCHW4(int8), so we transform layout to nchw4 firstly. | |||||
| if (options.has_set_nchw4()) { | |||||
| options.disable_nchw4(); | |||||
| optimizer.add_pass<FuseConvBiasNonlinPass>(); | |||||
| optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter()); | |||||
| } | |||||
| optimizer.add_pass<ExpandFusedArithPass>(); | optimizer.add_pass<ExpandFusedArithPass>(); | ||||
| optimizer.add_pass<gopt::TensorRTReplacePass>(); | optimizer.add_pass<gopt::TensorRTReplacePass>(); | ||||
| optimizer.add_pass<ArithFusePass>(); | optimizer.add_pass<ArithFusePass>(); | ||||
| @@ -32,7 +32,8 @@ public: | |||||
| namespace tensorrt { | namespace tensorrt { | ||||
| void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars); | |||||
| void transform_dest_vars_inplace(mgb::cg::VarNodeArray& dest_vars, | |||||
| cg::GraphCommonOptimizeOptions& options); | |||||
| } | } | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -1930,7 +1930,7 @@ TEST(TestTensorRTReplace, FuseConvAdd) { | |||||
| param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
| param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
| auto y = opr::Convolution::make(x, w, param); | auto y = opr::Convolution::make(x, w, param); | ||||
| auto nchw2nchw4 = [](SymbolVar x) { | auto nchw2nchw4 = [](SymbolVar x) { | ||||
| auto xshp = opr::GetVarShape::make(x); | auto xshp = opr::GetVarShape::make(x); | ||||
| @@ -1978,6 +1978,68 @@ TEST(TestTensorRTReplace, FuseConvAdd) { | |||||
| MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); | MGB_ASSERT_TENSOR_NEAR(outputs[1], outputs[3], 1e-3); | ||||
| } | } | ||||
| TEST(TestTensorRTReplace, FuseConvAddNchw2nchw4) { | |||||
| REQUIRE_GPU(1); | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{ | |||||
| 1.2f, 127 * 127}; | |||||
| auto graph = ComputingGraph::make(); | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::SharedDeviceTensor::make(*graph, *gen(shp)) | |||||
| .rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto x = mkvar("x", {32, 4, 28, 28}, dtype::QuantizedS8(2.5f)), | |||||
| w = mkcvar("w", {16, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
| b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
| opr::ConvBias::Param param; | |||||
| param.format = opr::ConvBias::Param::Format::NCHW; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto y = opr::ConvBias::make(x, w, b, param, {}, | |||||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
| auto z = opr::TypeCvt::make(y, dtype::Float32()); | |||||
| SymbolVar trt_z; | |||||
| SymbolVar mgb_z; | |||||
| ComputingGraph::Options opt; | |||||
| opt.graph_opt_level = 0; | |||||
| unpack_vector( | |||||
| gopt::GraphOptimizer{} | |||||
| .add_pass<gopt::FuseConvBiasNonlinPass>() | |||||
| .add_pass(gopt::EnableNCHW4Pass::make_nchw4_converter()) | |||||
| .add_pass<gopt::ExpandFusedArithPass>() | |||||
| .add_pass<gopt::TensorRTReplacePass>() | |||||
| .add_pass<gopt::ArithFusePass>() | |||||
| .apply({{z}}) | |||||
| .endpoint_vars(), | |||||
| trt_z); | |||||
| opt.graph_opt_level = 0; | |||||
| unpack_vector(gopt::GraphOptimizer{}.apply({{z}}).endpoint_vars(), | |||||
| mgb_z); | |||||
| ComputingGraph::OutputSpec outspec(2); | |||||
| SmallVector<HostTensorND> outputs(2); | |||||
| outspec[0] = make_callback_copy(trt_z, outputs[0], false); | |||||
| outspec[1] = make_callback_copy(mgb_z, outputs[1], false); | |||||
| graph->options().graph_opt.tensorrt = false; | |||||
| auto func = graph->compile(outspec); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_NEAR(outputs[0], outputs[1], 1e-3); | |||||
| } | |||||
| #endif // MGB_ENABLE_TENSOR_RT | #endif // MGB_ENABLE_TENSOR_RT | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||