GitOrigin-RevId: 4d1a9c6c84
tags/v0.5.0
@@ -542,7 +542,8 @@ def optimize_for_inference(
         use_nchw32=False,
         fuse_conv_bias_with_z=False,
         use_nchw88=False,
-        use_nchw44=False
+        use_nchw44=False,
+        use_chwn4=False
 ):
     """optimize computing graph for inference
@@ -566,6 +567,8 @@ def optimize_for_inference(
         times.
     :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
         nvidia tensorcore.
+    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
+        nvidia tensorcore.
     :return: list of transformed vars corresponding to given output vars
@@ -589,6 +592,7 @@ def optimize_for_inference(
         "use_nchw32": "nchw2nchw32",
         "use_nchw88": "nchw2nchw88",
         "use_nchw44": "nchw2nchw44",
+        "use_chwn4": "nchw42chwn4",
     }.items():
         if settings[k]:
             assert (
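A minimal usage sketch of the new Python switch, not part of the patch itself: it assumes the patched optimize_for_inference is importable from MegEngine's internal module and that y is an output SymbolVar of an already built graph; only the use_chwn4 keyword (and the other keywords visible in the signature above) come from this diff.

# Hypothetical caller of the patched function; the import path and the
# variable y are illustrative assumptions, not part of the patch.
from megengine._internal import optimize_for_inference  # assumed module path

# y is assumed to be the quantized output SymbolVar of a built graph.
y_opt, = optimize_for_inference(
    [y],
    fuse_conv_bias_with_z=True,   # fuse ReLU(conv_bias(x, w, b) + z) patterns first
    use_chwn4=True,               # request the NCHW4 -> CHWN4 layout transform
)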
@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
     SET(nchw2nchw88, NCHW2NCHW88);
     SET(nchw2nchw44, NCHW2NCHW44);
     SET(nchw2nchw32, NCHW2NCHW32);
+    SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
 };
@@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs):
         'enable_hwcd4': 'use_nhwcd4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
-        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_nchw32': 'use_nchw32',
+        'enable_chwn4': 'use_chwn4',
+        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
     }
     kwargs = {}
@@ -398,6 +399,12 @@ def main():
         help='transform the model format from NCHW4 to NCHW32 '
              'for inference on nvidia TensoCore'
     )
+    parser.add_argument(
+        '--enable-chwn4',
+        action='store_true',
+        help='transform the model format to CHWN4 '
+             'for inference, mainly used for nvidia tensorcore'
+    )
     parser.add_argument(
         '--enable-fuse-conv-bias-with-z',
         action='store_true',
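A small self-contained sketch of how the new command-line flag reaches the keyword argument, assuming the surrounding tool keeps the args_map-to-kwargs pattern shown in the hunk above; only --enable-chwn4 and use_chwn4 are taken from the patch, the rest is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--enable-chwn4',
    action='store_true',
    help='transform the model format to CHWN4 '
         'for inference, mainly used for nvidia tensorcore'
)

# argparse converts '--enable-chwn4' into the attribute 'enable_chwn4'.
args = parser.parse_args(['--enable-chwn4'])

# Mirrors the args_map -> kwargs translation in optimize_for_inference(args, outputs).
kwargs = {}
if args.enable_chwn4:
    kwargs['use_chwn4'] = True   # later forwarded as optimize_for_inference(..., **kwargs)
print(kwargs)                    # {'use_chwn4': True}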
@@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options(
         add_pass<ShuffleShuffleRemovePass>();
         add_pass<RemoveRedundantTypeCvtPass>();
     }
+    if (options->transform_nchw42chwn4()) {
+        add_pass<FuseConvBiasNonlinPass>();
+        add_pass<FuseConvBiasZPass>();
+        add_pass(EnableCHWN4Pass::make_chwn4_converter());
+        add_pass<ShuffleShuffleRemovePass>();
+        add_pass<RemoveRedundantTypeCvtPass>();
+    }
     if (options->fuse_conv_bias_nonlinearity) {
         add_pass<FuseConvBiasNonlinPass>();
@@ -395,6 +395,8 @@ namespace gopt {
             NCHW2NCHW44,    ///< compute using NCHW44 tensor format
             NCHW2NCHW32,    ///< compute using NCHW32 tensor format, used for
                             ///< tensorcore
+            NCHW42CHWN4,    ///< compute using CHWN4 tensor format, transformed
+                            ///< from NCHW4, mainly used for cuda
         };
         LayoutTransform layout_transform = LayoutTransform::DEFAULT;
         //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
@@ -424,6 +426,7 @@ namespace gopt {
         SET(nchw2nchw88, NCHW2NCHW88);
         SET(nchw2nchw44, NCHW2NCHW44);
         SET(nchw2nchw32, NCHW2NCHW32);
+        SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
     };
@@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) {
     y4 = opr::TypeCvt::make(y4, dtype::Float32());
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(
-            gopt::GraphOptimizer{}
-                    .add_pass<gopt::FuseConvBiasNonlinPass>()
-                    .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                    .add_pass<gopt::FuseConvBiasZPass>()
-                    .apply({{y4}})
-                    .endpoint_vars(),
-            y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass<gopt::FuseConvBiasNonlinPass>()
                           .add_pass<gopt::FuseConvBiasZPass>()
@@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
     auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param);
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(gopt::GraphOptimizer{}
-                          .add_pass<gopt::FuseConvBiasNonlinPass>()
-                          .add_pass<gopt::FuseConvBiasZPass>()
-                          .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                          .apply({{y2}})
-                          .endpoint_vars(),
-                  y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass<gopt::FuseConvBiasNonlinPass>()
                           .add_pass<gopt::FuseConvBiasZPass>()