@@ -194,6 +194,26 @@ R"__usage__(
     Execute operators with kernels implemented in MegDNN with CHWN4 tensor format. Can only be used
     on Nvidia GPUs, whose compute capability is above 6.1.
 )__usage__"
+R"__usage__(
+  --enable-nchw44
+    Execute operators with kernels implemented in MegDNN with NCHW44 tensor format. This can only
+    be used on armv7 and arm64 CPUs, and supports the float32, qint8 and int8x8x16 data types.
+)__usage__"
+R"__usage__(
+  --enable-nchw88
+    Execute operators with kernels implemented in MegDNN with NCHW88 tensor format. This can only
+    be used on x86 with the float32 data type.
+)__usage__"
+R"__usage__(
+  --enable-nchw44-dot
+    Execute operators with kernels implemented in MegDNN with NCHW44-DOT tensor format. This can
+    only be used on arm32 and arm64 CPUs with dot-product support, and only supports qint8 models.
+)__usage__"
+R"__usage__(
+  --weight-preprocess
+    Execute operators with weight preprocess, which can reduce operator execution time for
+    algorithms such as winograd and im2col, but may consume more memory.
+)__usage__"
 ;
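
The new switches are ordinary command-line flags and compose with the existing ones; a hypothetical invocation of the load-and-run tool (the binary name and model path below are illustrative, not taken from this diff):

    ./load_and_run model.mgb --enable-nchw44 --weight-preprocess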
@@ -1226,6 +1246,11 @@ Args Args::from_argv(int argc, char **argv) {
             graph_opt.graph_opt.weight_winograd_transform = true;
             continue;
         }
+        if (!strcmp(argv[i], "--weight-preprocess")) {
+            mgb_log_warn("enable weight-preprocess optimization");
+            graph_opt.graph_opt.enable_weight_preprocess();
+            continue;
+        }
 
         fprintf(stderr, "invalid arg: %s\n", argv[i]);
         ret.args_parse_ret = -1;
@@ -97,6 +97,9 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     //! whether to enable fast-run profiled winograd opr replace
     bool weight_winograd_transform = false;
+    //! whether to enable weight preprocess; if enabled it may use more
+    //! memory, disabled by default
+    bool weight_preprocess = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
         NCHW4,  ///< compute using NCHW4 tensor format
@@ -127,6 +130,7 @@ struct GraphCommonOptimizeOptions {
     SET(fuse_conv_bias_nonlinearity);
     SET(fuse_conv_bias_with_z);
     SET(weight_winograd_transform);
+    SET(weight_preprocess);
 #undef SET
 #define SET(_trans, _trans_capital) \
     GraphCommonOptimizeOptions& enable_##_trans() { \
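
The added SET(weight_preprocess) line is what provides the graph_opt.enable_weight_preprocess() call used by the argument parser above. The body of this first SET macro is not shown in the diff; assuming it follows the usual pattern for the neighbouring boolean options, it presumably expands to a chainable setter roughly like this sketch:

    // Sketch only: the real macro body is not part of this diff.
    GraphCommonOptimizeOptions& enable_weight_preprocess() {
        weight_preprocess = true;  // flip the flag declared above
        return *this;              // allow option chaining
    }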
@@ -963,6 +963,9 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
 bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
         const cg::OperatorNodeBase& opr) const {
+    if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
+        return false;
+    }
     if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
         return false;
     if (cg::is_const_var_value(opr.input(1)))
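
The new early-return makes the graph-level option the single gate for weight preprocessing: unless it is set, mixin_allow_weight_preprocess() bails out before any of the per-operator checks run. A minimal user-side sketch of turning it on (namespaces omitted to match the surrounding code; only accessors that appear elsewhere in this diff are assumed):

    auto graph = ComputingGraph::make();
    // either assign the field directly, as the test fixture below does ...
    graph->options().graph_opt.weight_preprocess = true;
    // ... or use the chainable setter wired up by SET(weight_preprocess) above
    graph->options().graph_opt.enable_weight_preprocess();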
@@ -2225,6 +2225,7 @@ protected:
         iw = ih;
         comp_node = CompNode::load("cpux");
         graph = ComputingGraph::make();
+        graph->options().graph_opt.weight_preprocess = is_weight_preprocess();
         TensorShape x_shape{1, ic, ih, iw}, w_shape{oc, ic, fh, fh};
         x_host = std::make_shared<HostTensorND>(comp_node, x_shape);
         auto x = opr::Host2DeviceCopy::make(*graph, x_host);
@@ -2247,6 +2248,8 @@ protected:
     void run() { func->execute().wait(); }
 
+    virtual bool is_weight_preprocess() { return true; }
+
     void TearDown() override {
         func.reset();
         // Triggers mock check
@@ -2346,6 +2349,33 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
     }
 }
 
+class TestNoWeightPreprocess : public TestWeightPreprocess {
+    bool is_weight_preprocess() override { return false; }
+};
+
+TEST_F(TestNoWeightPreprocess, NoPreprocess) {
+    using ::testing::_;
+    using ::testing::Return;
+    auto& mock = mock_conv();
+    MockAlgorithm algo;
+    EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _))
+            .WillRepeatedly(Return(&algo));
+    EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _))
+            .WillRepeatedly(Return(0));
+    EXPECT_CALL(mock, get_preprocess_workspace_in_bytes(_, _, _))
+            .WillRepeatedly(Return(0));
+    {
+        ::testing::InSequence seq;
+        // With weight preprocess disabled, none of the preprocess entry
+        // points should be called, and exec() must be passed a null
+        // preprocessed filter.
+        EXPECT_CALL(mock, deduce_preprocessed_filter_layout(_, _, _)).Times(0);
+        EXPECT_CALL(mock, exec_preprocess(_, _, _, _, _)).Times(0);
+        EXPECT_CALL(mock, exec(_, _, _, nullptr, _));
+        run();
+    }
+}
 
 }  // anonymous namespace
 #endif
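
To exercise only the new TestNoWeightPreprocess.NoPreprocess case, a standard gtest filter can be used; the test binary name below is an assumption, not taken from this diff:

    ./megbrain_test --gtest_filter='TestNoWeightPreprocess.NoPreprocess'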