GitOrigin-RevId: 9faa7ef068
tags/v0.5.0
| @@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { | |||
| continue; | |||
| } | |||
| #endif | |||
| if (!strcmp(argv[i], "--enable-chwn4")) { | |||
| mgb_log_warn("enable chwn4 optimization"); | |||
| graph_opt.graph_opt.enable_chwn4(); | |||
| #define cb(_layout) \ | |||
| if (!strcmp(argv[i], "--enable-" #_layout)) { \ | |||
| mgb_log_warn("enable " #_layout " optimization"); \ | |||
| graph_opt.graph_opt.enable_##_layout(); \ | |||
| continue; \ | |||
| } | |||
| cb(chwn4); | |||
| cb(nchw44); | |||
| cb(nchw88); | |||
| cb(nchw32); | |||
| cb(nhwcd4); | |||
| #undef cb | |||
| if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | |||
| mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | |||
| graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||
| continue; | |||
| } | |||
| if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { | |||
| mgb_log_warn("enable fuse_conv_bias_with_z optimization"); | |||
| graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); | |||
| continue; | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| @@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
| options().graph_opt.winograd_transform = false; | |||
| gopt::transform_vars_inplace_with_winograd(dest_vars); | |||
| } | |||
| if (options().graph_opt.transform_chwn4()) { | |||
| gopt::GraphOptimizer optimizer; | |||
| optimizer.apply_optimize_options(options().graph_opt); | |||
| options().graph_opt.layout_transform = | |||
| cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; | |||
| optimizer.apply_inplace(dest_vars); | |||
| } | |||
| #if MGB_JIT | |||
| if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | |||
| @@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
| optimizer.apply_inplace(dest_vars); | |||
| } | |||
| #endif | |||
| gopt::GraphOptimizer optimizer; | |||
| optimizer.apply_optimize_options(options().graph_opt); | |||
| options().graph_opt.reset(); | |||
| optimizer.apply_inplace(dest_vars); | |||
| const OprNodeArray* opr_seq = nullptr; | |||
| CompSeqExtraInfo extra_info; | |||
| @@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { | |||
| bool f16_io_comp = false; | |||
| //! whether to enable conv bias nonlinearity fusion | |||
| bool fuse_conv_bias_nonlinearity = false; | |||
| //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||
| //! + z -> conv_bias(x, w, b, z) | |||
| bool fuse_conv_bias_with_z = false; | |||
| enum LayoutTransform : uint32_t { | |||
| DEFAULT, | |||
| NHWCD4, ///< compute using NHWCD4 tensor format | |||
| @@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { | |||
| ///< used for cuda | |||
| }; | |||
| LayoutTransform layout_transform = LayoutTransform::DEFAULT; | |||
| //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||
| //! + z -> conv_bias(x, w, b, z) | |||
| bool fuse_conv_bias_with_z = false; | |||
| void reset() { | |||
| f16_io_f32_comp = false; | |||
| f16_io_comp = false; | |||
| fuse_conv_bias_nonlinearity = false; | |||
| fuse_conv_bias_with_z = false; | |||
| layout_transform = LayoutTransform::DEFAULT; | |||
| } | |||
| #define SET(n) \ | |||
| GraphCommonOptimizeOptions& enable_##n() { \ | |||
| @@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { | |||
| #undef SET | |||
| #define SET(_trans, _trans_capital) \ | |||
| GraphCommonOptimizeOptions& enable_##_trans() { \ | |||
| mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ | |||
| layout_transform = LayoutTransform::_trans_capital; \ | |||
| return *this; \ | |||
| } \ | |||
| @@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||
| const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
| const cg::GraphCommonOptimizeOptions& options) { | |||
| bool need_param_fuse = false; | |||
| if (options.f16_io_comp) { | |||
| add_pass(ConvertF32ToF16Pass::make(false)); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.f16_io_f32_comp) { | |||
| add_pass(ConvertF32ToF16Pass::make(true)); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.transform_nhwcd4()) { | |||
| add_pass(ConvertFormatPass::make_nhwcd4_converter()); | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.transform_nchw88()) { | |||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.transform_nchw44()) { | |||
| add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.transform_nchw32()) { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| @@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
| add_pass(EnableTensorCorePass::make_tensorcore_converter()); | |||
| add_pass<ShuffleShuffleRemovePass>(); | |||
| add_pass<RemoveRedundantTypeCvtPass>(); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.transform_chwn4()) { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| @@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
| add_pass(EnableCHWN4Pass::make_chwn4_converter()); | |||
| add_pass<ShuffleShuffleRemovePass>(); | |||
| add_pass<RemoveRedundantTypeCvtPass>(); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.fuse_conv_bias_nonlinearity) { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| need_param_fuse = true; | |||
| } | |||
| if (options.fuse_conv_bias_with_z) { | |||
| add_pass<FuseConvBiasNonlinPass>(); | |||
| add_pass<FuseConvBiasZPass>(); | |||
| need_param_fuse = true; | |||
| } | |||
| if (need_param_fuse) { | |||
| add_pass<ParamFusePass>(); | |||
| } | |||
| add_pass<ParamFusePass>(); | |||
| return *this; | |||
| } | |||