GitOrigin-RevId: a058776be3
tags/v1.9.0
| @@ -1038,7 +1038,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
| 'NCHW_NCHW64 = 27', | 'NCHW_NCHW64 = 27', | ||||
| 'NCHW64_NCHW = 28', | 'NCHW64_NCHW = 28', | ||||
| 'NCHW_NHWC = 29', | 'NCHW_NHWC = 29', | ||||
| 'NHWC_NCHW = 30', | |||||
| 'NHWC_NCHW = 30', | |||||
| 'NHWCD4I_NHWC = 31', | |||||
| ) | ) | ||||
| ) | ) | ||||
| @@ -114,6 +114,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds | |||||
| dst[3] = src[2]; | dst[3] = src[2]; | ||||
| dst[4] = 4; | dst[4] = 4; | ||||
| break; | break; | ||||
| case Param::Mode::NHWCD4I_NHWC: | |||||
| case Param::Mode::NHWCD4_NHWC: | case Param::Mode::NHWCD4_NHWC: | ||||
| megdnn_assert(src.ndim == 5); | megdnn_assert(src.ndim == 5); | ||||
| dst.ndim = 4; | dst.ndim = 4; | ||||
| @@ -331,6 +332,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||||
| CHECK_SRC(DefaultTensorFormat::make()); | CHECK_SRC(DefaultTensorFormat::make()); | ||||
| dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); | dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); | ||||
| break; | break; | ||||
| case Param::Mode::NHWCD4I_NHWC: | |||||
| case Param::Mode::NHWCD4I_NCHW: | case Param::Mode::NHWCD4I_NCHW: | ||||
| CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); | CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); | ||||
| dst = DefaultTensorFormat::make(); | dst = DefaultTensorFormat::make(); | ||||
| @@ -594,6 +596,7 @@ void RelayoutFormat::deduce_exec_layout( | |||||
| .dimshuffle({0, 1, 3, 2, 4}); | .dimshuffle({0, 1, 3, 2, 4}); | ||||
| exec_dst = dst; | exec_dst = dst; | ||||
| break; | break; | ||||
| case Param::Mode::NHWCD4I_NHWC: | |||||
| case Param::Mode::NHWCD4_NHWC: | case Param::Mode::NHWCD4_NHWC: | ||||
| // src is {N, H, CB, W, 4} | // src is {N, H, CB, W, 4} | ||||
| // dst is {N, H, W, C}, | // dst is {N, H, W, C}, | ||||
| @@ -1002,7 +1002,9 @@ void ConvertFormatPass::apply(OptState& state) const { | |||||
| rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
| //! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) + | //! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) + | ||||
| //! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4) | |||||
| //! relayout_format(NCHW->NHWCD4) into a single relayout_format(NHWC->NHWCD4). It also | |||||
| //! merges consecutive relayout_format(NHWCD4 -> NCHW) + dimshuffle(NCHW -> NHWC) into a | |||||
| //! single relayout_format(NHWCD4 -> NHWC). | |||||
| auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) { | auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) { | ||||
| auto opr_is_relayout = [](OperatorNodeBase* opr) { | auto opr_is_relayout = [](OperatorNodeBase* opr) { | ||||
| return opr->try_cast_final<opr::RelayoutFormat>(); | return opr->try_cast_final<opr::RelayoutFormat>(); | ||||
| @@ -1019,23 +1021,48 @@ void ConvertFormatPass::apply(OptState& state) const { | |||||
| } | } | ||||
| return false; | return false; | ||||
| }; | }; | ||||
| auto this_opr_is_relayout = opr_is_relayout(opr); | |||||
| auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr); | |||||
| if (this_opr_is_relayout) { | |||||
| prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr()); | |||||
| } | |||||
| if (this_opr_is_relayout && prev_opr_is_dimshuffle) { | |||||
| if (this_opr_is_relayout->param().mode == | |||||
| megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I && | |||||
| match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) { | |||||
| auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0)); | |||||
| auto new_param = megdnn::param::RelayoutFormat(); | |||||
| new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I; | |||||
| auto new_opr = opr::RelayoutFormat::make(inp, new_param); | |||||
| rewriter.replace_var(opr->output(0), new_opr.node(), nullptr); | |||||
| //! dimshuffle + relayout_format | |||||
| { | |||||
| auto this_opr_is_relayout = opr_is_relayout(opr); | |||||
| auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr); | |||||
| if (this_opr_is_relayout) { | |||||
| prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr()); | |||||
| } | |||||
| if (this_opr_is_relayout && prev_opr_is_dimshuffle) { | |||||
| //! MegEngine only accepts NCHW input | |||||
| if (this_opr_is_relayout->param().mode == | |||||
| megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I && | |||||
| match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) { | |||||
| auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0)); | |||||
| auto new_param = megdnn::param::RelayoutFormat(); | |||||
| new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I; | |||||
| auto new_opr = opr::RelayoutFormat::make(inp, new_param); | |||||
| rewriter.replace_var(opr->output(0), new_opr.node(), nullptr); | |||||
| } | |||||
| } else { | |||||
| rewriter.auto_replace_outputs(opr); | |||||
| } | |||||
| } | |||||
| //! relayout_format + dimshuffle | |||||
| { | |||||
| auto this_opr_is_dimshuffle = opr_is_dimshuffle(opr); | |||||
| auto prev_opr_is_relayout = static_cast<opr::RelayoutFormat*>(nullptr); | |||||
| if (this_opr_is_dimshuffle) { | |||||
| prev_opr_is_relayout = opr_is_relayout(opr->input(0)->owner_opr()); | |||||
| } | |||||
| if (this_opr_is_dimshuffle && prev_opr_is_relayout) { | |||||
| if (prev_opr_is_relayout->param().mode == | |||||
| megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW && | |||||
| match_pattern(this_opr_is_dimshuffle->param(), {0, 2, 3, 1})) { | |||||
| auto inp = rewriter.get_var(prev_opr_is_relayout->input(0)); | |||||
| auto new_param = megdnn::param::RelayoutFormat(); | |||||
| new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC; | |||||
| auto new_opr = opr::RelayoutFormat::make(inp, new_param); | |||||
| rewriter.replace_var(opr->output(0), new_opr.node(), nullptr); | |||||
| } | |||||
| } else { | |||||
| rewriter.auto_replace_outputs(opr); | |||||
| } | } | ||||
| } else { | |||||
| rewriter.auto_replace_outputs(opr); | |||||
| } | } | ||||
| }; | }; | ||||
| state.graph().iter(on_opr_merge); | state.graph().iter(on_opr_merge); | ||||
| @@ -1365,6 +1365,71 @@ TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) { | |||||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | ||||
| } | } | ||||
| TEST(TestGoptInference, MergeRelayoutFormatAndDimShuffle) { | |||||
| // hwcd4 is only supported in naive handle | |||||
| NaiveMegDNNHandleScope naive_megdnn_handle; | |||||
| HostTensorGenerator<> gen; | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| auto graph = ComputingGraph::make(); | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||||
| }; | |||||
| auto host_x = gen({2, 8, 16, 32}, cn); | |||||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||||
| auto a = mkvar("a", {1}); | |||||
| auto b = mkvar("b", {1}); | |||||
| auto z = x * a + b; | |||||
| //! to NHWC | |||||
| auto y = opr::Dimshuffle::make(z, {0, 2, 3, 1}); | |||||
| SymbolVar y_opt; | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| options.enable_nhwcd4(); | |||||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
| ASSERT_EQ(0, find_opr_num<opr::Dimshuffle>(y_opt)); | |||||
| auto check = [](SymbolVar endpoint) -> bool { | |||||
| bool valid = true; | |||||
| auto cb = [&](cg::OperatorNodeBase* opr) { | |||||
| if (opr->same_type<opr::RelayoutFormat>()) { | |||||
| auto mode = opr->try_cast_final<opr::RelayoutFormat>()->param().mode; | |||||
| //! The first relayout_format opr's mode is NCHW_NHWCD4I. The second is | |||||
| //! NHWCD4I_NHWC | |||||
| if (mode == megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I || | |||||
| mode == megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC) { | |||||
| valid &= true; | |||||
| } else { | |||||
| valid &= false; | |||||
| } | |||||
| } | |||||
| }; | |||||
| cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||||
| return valid; | |||||
| }; | |||||
| ASSERT_EQ(true, check(y_opt)); | |||||
| graph->compile({{y_opt, {}}}) | |||||
| ->to_json() | |||||
| ->writeto_fpath(output_file( | |||||
| "TestGoptInference.MergeRelayoutFormatAndDimShuffle.json")); | |||||
| HostTensorND host_y; | |||||
| HostTensorND host_y_opt; | |||||
| auto func = graph->compile( | |||||
| {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||||
| *host_x = *gen({8, 8, 16, 16}, cn); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||||
| } | |||||
| TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { | TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { | ||||
| // hwcd4 is only supported in naive handle | // hwcd4 is only supported in naive handle | ||||
| NaiveMegDNNHandleScope naive_megdnn_handle; | NaiveMegDNNHandleScope naive_megdnn_handle; | ||||