GitOrigin-RevId: 9faa7ef068
tags/v0.5.0
| @@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| #endif | #endif | ||||
| if (!strcmp(argv[i], "--enable-chwn4")) { | |||||
| mgb_log_warn("enable chwn4 optimization"); | |||||
| graph_opt.graph_opt.enable_chwn4(); | |||||
| #define cb(_layout) \ | |||||
| if (!strcmp(argv[i], "--enable-" #_layout)) { \ | |||||
| mgb_log_warn("enable " #_layout " optimization"); \ | |||||
| graph_opt.graph_opt.enable_##_layout(); \ | |||||
| continue; \ | |||||
| } | |||||
| cb(chwn4); | |||||
| cb(nchw44); | |||||
| cb(nchw88); | |||||
| cb(nchw32); | |||||
| cb(nhwcd4); | |||||
| #undef cb | |||||
| if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | |||||
| mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | |||||
| graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||||
| continue; | |||||
| } | |||||
| if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { | |||||
| mgb_log_warn("enable fuse_conv_bias_with_z optimization"); | |||||
| graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); | |||||
| continue; | continue; | ||||
| } | } | ||||
| #if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
| @@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
| options().graph_opt.winograd_transform = false; | options().graph_opt.winograd_transform = false; | ||||
| gopt::transform_vars_inplace_with_winograd(dest_vars); | gopt::transform_vars_inplace_with_winograd(dest_vars); | ||||
| } | } | ||||
| if (options().graph_opt.transform_chwn4()) { | |||||
| gopt::GraphOptimizer optimizer; | |||||
| optimizer.apply_optimize_options(options().graph_opt); | |||||
| options().graph_opt.layout_transform = | |||||
| cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; | |||||
| optimizer.apply_inplace(dest_vars); | |||||
| } | |||||
| #if MGB_JIT | #if MGB_JIT | ||||
| if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | ||||
| @@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
| optimizer.apply_inplace(dest_vars); | optimizer.apply_inplace(dest_vars); | ||||
| } | } | ||||
| #endif | #endif | ||||
| gopt::GraphOptimizer optimizer; | |||||
| optimizer.apply_optimize_options(options().graph_opt); | |||||
| options().graph_opt.reset(); | |||||
| optimizer.apply_inplace(dest_vars); | |||||
| const OprNodeArray* opr_seq = nullptr; | const OprNodeArray* opr_seq = nullptr; | ||||
| CompSeqExtraInfo extra_info; | CompSeqExtraInfo extra_info; | ||||
| @@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { | |||||
| bool f16_io_comp = false; | bool f16_io_comp = false; | ||||
| //! whether to enable conv bias nonlinearity fusion | //! whether to enable conv bias nonlinearity fusion | ||||
| bool fuse_conv_bias_nonlinearity = false; | bool fuse_conv_bias_nonlinearity = false; | ||||
| //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||||
| //! + z -> conv_bias(x, w, b, z) | |||||
| bool fuse_conv_bias_with_z = false; | |||||
| enum LayoutTransform : uint32_t { | enum LayoutTransform : uint32_t { | ||||
| DEFAULT, | DEFAULT, | ||||
| NHWCD4, ///< compute using NHWCD4 tensor format | NHWCD4, ///< compute using NHWCD4 tensor format | ||||
| @@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { | |||||
| ///< used for cuda | ///< used for cuda | ||||
| }; | }; | ||||
| LayoutTransform layout_transform = LayoutTransform::DEFAULT; | LayoutTransform layout_transform = LayoutTransform::DEFAULT; | ||||
| //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||||
| //! + z -> conv_bias(x, w, b, z) | |||||
| bool fuse_conv_bias_with_z = false; | |||||
| void reset() { | |||||
| f16_io_f32_comp = false; | |||||
| f16_io_comp = false; | |||||
| fuse_conv_bias_nonlinearity = false; | |||||
| fuse_conv_bias_with_z = false; | |||||
| layout_transform = LayoutTransform::DEFAULT; | |||||
| } | |||||
| #define SET(n) \ | #define SET(n) \ | ||||
| GraphCommonOptimizeOptions& enable_##n() { \ | GraphCommonOptimizeOptions& enable_##n() { \ | ||||
| @@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { | |||||
| #undef SET | #undef SET | ||||
| #define SET(_trans, _trans_capital) \ | #define SET(_trans, _trans_capital) \ | ||||
| GraphCommonOptimizeOptions& enable_##_trans() { \ | GraphCommonOptimizeOptions& enable_##_trans() { \ | ||||
| mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ | |||||
| layout_transform = LayoutTransform::_trans_capital; \ | layout_transform = LayoutTransform::_trans_capital; \ | ||||
| return *this; \ | return *this; \ | ||||
| } \ | } \ | ||||
| @@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||||
| const GraphOptimizer& GraphOptimizer::apply_optimize_options( | const GraphOptimizer& GraphOptimizer::apply_optimize_options( | ||||
| const cg::GraphCommonOptimizeOptions& options) { | const cg::GraphCommonOptimizeOptions& options) { | ||||
| bool need_param_fuse = false; | |||||
| if (options.f16_io_comp) { | if (options.f16_io_comp) { | ||||
| add_pass(ConvertF32ToF16Pass::make(false)); | add_pass(ConvertF32ToF16Pass::make(false)); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.f16_io_f32_comp) { | if (options.f16_io_f32_comp) { | ||||
| add_pass(ConvertF32ToF16Pass::make(true)); | add_pass(ConvertF32ToF16Pass::make(true)); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.transform_nhwcd4()) { | if (options.transform_nhwcd4()) { | ||||
| add_pass(ConvertFormatPass::make_nhwcd4_converter()); | add_pass(ConvertFormatPass::make_nhwcd4_converter()); | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.transform_nchw88()) { | if (options.transform_nchw88()) { | ||||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.transform_nchw44()) { | if (options.transform_nchw44()) { | ||||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.transform_nchw32()) { | if (options.transform_nchw32()) { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| @@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||||
| add_pass(EnableTensorCorePass::make_tensorcore_converter()); | add_pass(EnableTensorCorePass::make_tensorcore_converter()); | ||||
| add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
| add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.transform_chwn4()) { | if (options.transform_chwn4()) { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| @@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||||
| add_pass(EnableCHWN4Pass::make_chwn4_converter()); | add_pass(EnableCHWN4Pass::make_chwn4_converter()); | ||||
| add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
| add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.fuse_conv_bias_nonlinearity) { | if (options.fuse_conv_bias_nonlinearity) { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| need_param_fuse = true; | |||||
| } | } | ||||
| if (options.fuse_conv_bias_with_z) { | if (options.fuse_conv_bias_with_z) { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
| need_param_fuse = true; | |||||
| } | |||||
| if (need_param_fuse) { | |||||
| add_pass<ParamFusePass>(); | |||||
| } | } | ||||
| add_pass<ParamFusePass>(); | |||||
| return *this; | return *this; | ||||
| } | } | ||||