diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index aebedd24..7b35d7fd 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() { inps[i], RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW); return ovar.node(); + } else if (fmt == Format::NHWC) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW); + return ovar.node(); } else { mgb_assert(fmt == Format::NCHW64); auto ovar = RelayoutPlaceholder::make( @@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() { inps[i], RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); return ovar.node(); + } else if (fmt == Format::NHWC) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4); + return ovar.node(); } else { mgb_assert(fmt == Format::NCHW64); auto ovar = RelayoutPlaceholder::make( @@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() { inps[i], RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); return ovar.node(); + } else if (fmt == Format::NHWC) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32); + return ovar.node(); } else { mgb_assert(fmt == Format::NCHW64); auto ovar = RelayoutPlaceholder::make( @@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() { inps[i], RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); return ovar.node(); + } else if (fmt == Format::NHWC) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); + return ovar.node(); } else { mgb_assert(fmt == Format::NCHW32); auto ovar = RelayoutPlaceholder::make( @@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() { return ret; }; + auto try_transform_to_nhwc = + [make_new_conv, &format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> VarNode* { + // fint4XWint4 and 
fuint4XWint4 + mgb_assert(opr->input().size()==new_inp.size()); + bool check_dtype = + (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || + new_inp[0]->dtype().enumv() == + DTypeEnum::Quantized4Asymm) && + new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; + if (opr->input().size() >= 3) + check_dtype &= + new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; + if (opr->input().size() >= 4) + check_dtype &= new_inp[3]->dtype().enumv() == + new_inp[0]->dtype().enumv(); + if (!check_dtype) + return nullptr; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + bool check_channels = out_channels % 8 == 0 && in_channels % 8 == 0; + if (!check_channels) + return nullptr; + auto inps = new_inp; + auto process = [&](size_t i) -> VarNode* { + auto iter = format_map.find(new_inp[i]->owner_opr()); + if (iter == format_map.end()) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); + return ovar.node(); + } else { + const auto& fmt = iter->second; + if (fmt == Format::NHWC) { + return inps[i]; + } else if (fmt == Format::NCHW4) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC); + return ovar.node(); + } else if (fmt == Format::NCHW32) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC); + return ovar.node(); + } else { + mgb_assert(fmt == Format::NCHW64); + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC); + return ovar.node(); + } + } + }; + for (size_t i = 0; i < inps.size(); ++i) { + inps[i] = process(i); + } + auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); + auto ret = make_new_conv(inps, &conv_bias, Format::NHWC); + format_map.insert(std::make_pair(ret->owner_opr(), Format::NHWC)); + return ret; + }; + // replace rule for conv bias opr auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4, 
try_transform_to_nchw32, - try_transform_to_nchw64, try_transform_to_nchw]( + try_transform_to_nchw64, + try_transform_to_nhwc, try_transform_to_nchw]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { using Param = megdnn::param::ConvBias; @@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() { VarNode* new_var = nullptr; if ((new_var = try_transform_to_nchw32(opr, new_inp)) || (new_var = try_transform_to_nchw4(opr, new_inp)) || - (new_var = try_transform_to_nchw64(opr, new_inp))|| + (new_var = try_transform_to_nchw64(opr, new_inp)) || + (new_var = try_transform_to_nhwc(opr, new_inp)) || (new_var = try_transform_to_nchw(opr, new_inp))) { return new_var->owner_opr(); } else { @@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() { NCHW_TO_NCHW4) .node(); break; + case Format::NHWC: + inps[1] = RelayoutPlaceholder::make( + inps[1], RelayoutPlaceholder::LayoutType:: + NCHW_TO_NHWC) + .node(); + break; case Format::NCHW32: inps[1] = RelayoutPlaceholder::make( inps[1], RelayoutPlaceholder::LayoutType:: @@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() { cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64), cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), + cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC), + cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4), + cb(NHWC, NCHW32), cb(NHWC, NCHW64), #undef cb }; auto inps = new_inp; @@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() { case Format::NCHW: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW64) + NCHW_TO_NHWC) .node(); break; case Format::NCHW4: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW64) + NCHW4_TO_NHWC) .node(); break; case Format::NCHW32: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW64) + NCHW32_TO_NHWC) .node(); break; default: - 
mgb_assert(cur == Format::NCHW64); + mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); } + auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; auto param = warp.param(); - param.format = Format::NCHW64; + param.format = target_format; SymbolVar new_warp; if (inps.size() == 3) { new_warp = opr::WarpPerspectiveForward::make( @@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() { warp.config()); } auto ret = new_warp.node()->owner_opr(); - format_map.insert(std::make_pair(ret, Format::NCHW64)); + format_map.insert(std::make_pair(ret, target_format)); return ret; } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { Format cur; @@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() { NCHW_TO_NCHW4) .node(); break; + case Format::NHWC: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NHWC_TO_NCHW4) + .node(); + break; case Format::NCHW32: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: @@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() { case Format::NCHW: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW64) + NCHW_TO_NHWC) .node(); break; case Format::NCHW4: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW64) + NCHW4_TO_NHWC) .node(); break; case Format::NCHW32: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW64) + NCHW32_TO_NHWC) .node(); break; default: - mgb_assert(cur == Format::NCHW64); + mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); } - + auto target_format = cur == Format::NCHW64 ? 
cur : Format::NHWC; auto param = pooling.param(); - param.format = Format::NCHW64; + param.format = target_format; auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config()); auto ret = new_pool.node()->owner_opr(); - format_map.insert(std::make_pair(ret, Format::NCHW64)); + format_map.insert(std::make_pair(ret, target_format)); return ret; } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { Format cur; @@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() { } else { cur = iter->second; } - size_t in_channels = new_inp[0]->shape()[1]; bool use_nchw32 = false; auto inps = new_inp; using LayoutType = RelayoutPlaceholder::LayoutType; switch (cur) { case Format::NCHW: { + size_t in_channels = new_inp[0]->shape()[1]; use_nchw32 = in_channels % 32 == 0; auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32 : LayoutType::NCHW_TO_NCHW4; @@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() { .node(); break; } + case Format::NHWC: { + size_t in_channels = new_inp[0]->shape()[3]; + use_nchw32 = in_channels % 32 == 0; + auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32 + : LayoutType::NHWC_TO_NCHW4; + inps[0] = RelayoutPlaceholder::make(inps[0], layout_type) + .node(); + break; + } case Format::NCHW64: inps[0] = RelayoutPlaceholder::make( inps[0], RelayoutPlaceholder::LayoutType:: @@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() { auto fmt = iter != format_map.end()?iter->second:Format::NCHW; if (iter != format_map.end()) { switch (fmt) { + case Format::NHWC: + inps[i] = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType:: + NHWC_TO_NCHW) + .node(); + break; case Format::NCHW4: inps[i] = RelayoutPlaceholder::make( inps[i],