fast-run
GitOrigin-RevId: 49ccbdf2d4
tags/v0.5.0
| @@ -154,7 +154,7 @@ public: | |||
| for (auto&& algo : matmul_algos) { | |||
| if (algo->type() == nullptr) | |||
| continue; | |||
| for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) { | |||
| for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
| refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| @@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||
| cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); | |||
| cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); | |||
| cb(nchw4, { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass<FuseConvBiasZPass>(); | |||
| @@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass(ConvertFormatPass::make_nhwcd4_converter()); | |||
| }); | |||
| cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); | |||
| cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); | |||
| cb(nchw44_dot, | |||
| { add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); }); | |||
| cb(nchw88, { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | |||
| add_pass<ShuffleShuffleRemovePass>(); | |||
| }); | |||
| cb(nchw44, { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | |||
| add_pass<ShuffleShuffleRemovePass>(); | |||
| }); | |||
| cb(nchw44_dot, { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); | |||
| add_pass<ShuffleShuffleRemovePass>(); | |||
| }); | |||
| cb(nchw32, { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass<FuseConvBiasZPass>(); | |||
| @@ -707,7 +707,9 @@ template <> | |||
| void AlgoChooser<megdnn::ConvBias>::ExeContext:: | |||
| modify_param_with_weights_preprocessed( | |||
| typename TimedProfiler<megdnn::ConvBias>::Param& param) const { | |||
| if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { | |||
| if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW || | |||
| param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44 || | |||
| param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW88) { | |||
| auto winograd_param = | |||
| megdnn::ConvBias::parse_winograd_name(param.algo_name); | |||
| if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) { | |||
| @@ -727,8 +729,18 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext:: | |||
| filter_transform_layout); | |||
| param.shapes[1] = filter_transform_layout; | |||
| param.dtypes[1] = filter_transform_layout.dtype.enumv(); | |||
| param.opr_param.format = megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; | |||
| if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { | |||
| param.opr_param.format = | |||
| megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; | |||
| } else if (param.opr_param.format == | |||
| megdnn::ConvBias::Param::Format::NCHW44) { | |||
| param.opr_param.format = | |||
| megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD; | |||
| } else if (param.opr_param.format == | |||
| megdnn::ConvBias::Param::Format::NCHW) { | |||
| param.opr_param.format = | |||
| megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD; | |||
| } | |||
| param.opr_param.output_block_size = winograd_param.output_block_size; | |||
| } | |||
| } | |||
| @@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||
| spatial_start = 2; | |||
| break; | |||
| case Param::Format::NCHW_WINOGRAD: | |||
| case Param::Format::NCHW44_WINOGRAD: | |||
| case Param::Format::NCHW88_WINOGRAD: | |||
| cpos = 1; | |||
| spatial_start = 0; | |||
| @@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||
| uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); | |||
| uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); | |||
| if (param.format == Param::Format::NCHW_WINOGRAD || | |||
| param.format == Param::Format::NCHW44_WINOGRAD || | |||
| param.format == Param::Format::NCHW88_WINOGRAD) { | |||
| mgb_assert(opr->same_type<opr::ConvBias>(), | |||
| "Only conv bias support NCHW_WINOGRAD"); | |||
| "Only conv bias support WINOGRAD"); | |||
| auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>(); | |||
| uint32_t output_block_size = conv_bias_opr.param().output_block_size; | |||
| mgb_assert(fh == fw, | |||
| @@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||
| return dst_shape.total_nr_elems() * fh * fw * | |||
| static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; | |||
| } | |||
| if (param.format == Param::Format::NCHW44_WINOGRAD) { | |||
| return dst_shape.total_nr_elems() * fh * fw * | |||
| static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2; | |||
| } | |||
| return dst_shape.total_nr_elems() * fh * fw * | |||
| static_cast<uint64_t>(src_shape[cpos]) / group * 2; | |||
| } | |||