|
|
|
@@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NHWC) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW64); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
@@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NHWC) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW64); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
@@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NHWC) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW64); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
@@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NHWC) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW32); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
@@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_transform_to_nhwc = |
|
|
|
[make_new_conv, &format_map]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) -> VarNode* { |
|
|
|
// fint4XWint4 and fuint4XWint4 |
|
|
|
mgb_assert(opr->input().size()==new_inp.size()); |
|
|
|
bool check_dtype = |
|
|
|
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
new_inp[0]->dtype().enumv() == |
|
|
|
DTypeEnum::Quantized4Asymm) && |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= new_inp[3]->dtype().enumv() == |
|
|
|
new_inp[0]->dtype().enumv(); |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
size_t out_channels = opr->input(1)->shape()[0]; |
|
|
|
size_t in_channels = opr->input(1)->shape()[1]; |
|
|
|
bool check_channels = out_channels % 8 == 0 && in_channels % 8 == 0; |
|
|
|
if (!check_channels) |
|
|
|
return nullptr; |
|
|
|
auto inps = new_inp; |
|
|
|
auto process = [&](size_t i) -> VarNode* { |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
const auto& fmt = iter->second; |
|
|
|
if (fmt == Format::NHWC) { |
|
|
|
return inps[i]; |
|
|
|
} else if (fmt == Format::NCHW4) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NCHW32) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW64); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC); |
|
|
|
return ovar.node(); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
} |
|
|
|
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
auto ret = make_new_conv(inps, &conv_bias, Format::NHWC); |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NHWC)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
// replace rule for conv bias opr |
|
|
|
auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4, |
|
|
|
try_transform_to_nchw32, |
|
|
|
try_transform_to_nchw64, try_transform_to_nchw]( |
|
|
|
try_transform_to_nchw64, |
|
|
|
try_transform_to_nhwc, try_transform_to_nchw]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
using Param = megdnn::param::ConvBias; |
|
|
|
@@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
VarNode* new_var = nullptr; |
|
|
|
if ((new_var = try_transform_to_nchw32(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nchw4(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nchw64(opr, new_inp))|| |
|
|
|
(new_var = try_transform_to_nchw64(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nhwc(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nchw(opr, new_inp))) { |
|
|
|
return new_var->owner_opr(); |
|
|
|
} else { |
|
|
|
@@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
NCHW_TO_NCHW4) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NHWC: |
|
|
|
inps[1] = RelayoutPlaceholder::make( |
|
|
|
inps[1], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW32: |
|
|
|
inps[1] = RelayoutPlaceholder::make( |
|
|
|
inps[1], RelayoutPlaceholder::LayoutType:: |
|
|
|
@@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64), |
|
|
|
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), |
|
|
|
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), |
|
|
|
cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC), |
|
|
|
cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4), |
|
|
|
cb(NHWC, NCHW32), cb(NHWC, NCHW64), |
|
|
|
#undef cb |
|
|
|
}; |
|
|
|
auto inps = new_inp; |
|
|
|
@@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
case Format::NCHW: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW_TO_NCHW64) |
|
|
|
NCHW_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW4: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW4_TO_NCHW64) |
|
|
|
NCHW4_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW32: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW32_TO_NCHW64) |
|
|
|
NCHW32_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW64); |
|
|
|
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); |
|
|
|
} |
|
|
|
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; |
|
|
|
auto param = warp.param(); |
|
|
|
param.format = Format::NCHW64; |
|
|
|
param.format = target_format; |
|
|
|
SymbolVar new_warp; |
|
|
|
if (inps.size() == 3) { |
|
|
|
new_warp = opr::WarpPerspectiveForward::make( |
|
|
|
@@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
warp.config()); |
|
|
|
} |
|
|
|
auto ret = new_warp.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW64)); |
|
|
|
format_map.insert(std::make_pair(ret, target_format)); |
|
|
|
return ret; |
|
|
|
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
Format cur; |
|
|
|
@@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
NCHW_TO_NCHW4) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NHWC: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NHWC_TO_NCHW4) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW32: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
@@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
case Format::NCHW: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW_TO_NCHW64) |
|
|
|
NCHW_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW4: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW4_TO_NCHW64) |
|
|
|
NCHW4_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW32: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
NCHW32_TO_NCHW64) |
|
|
|
NCHW32_TO_NHWC) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW64); |
|
|
|
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); |
|
|
|
} |
|
|
|
|
|
|
|
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; |
|
|
|
auto param = pooling.param(); |
|
|
|
param.format = Format::NCHW64; |
|
|
|
param.format = target_format; |
|
|
|
auto new_pool = |
|
|
|
opr::PoolingForward::make(inps[0], param, pooling.config()); |
|
|
|
auto ret = new_pool.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW64)); |
|
|
|
format_map.insert(std::make_pair(ret, target_format)); |
|
|
|
return ret; |
|
|
|
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
Format cur; |
|
|
|
@@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
} else { |
|
|
|
cur = iter->second; |
|
|
|
} |
|
|
|
size_t in_channels = new_inp[0]->shape()[1]; |
|
|
|
bool use_nchw32 = false; |
|
|
|
auto inps = new_inp; |
|
|
|
using LayoutType = RelayoutPlaceholder::LayoutType; |
|
|
|
switch (cur) { |
|
|
|
case Format::NCHW: { |
|
|
|
size_t in_channels = new_inp[0]->shape()[1]; |
|
|
|
use_nchw32 = in_channels % 32 == 0; |
|
|
|
auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32 |
|
|
|
: LayoutType::NCHW_TO_NCHW4; |
|
|
|
@@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
} |
|
|
|
case Format::NHWC: { |
|
|
|
size_t in_channels = new_inp[0]->shape()[3]; |
|
|
|
use_nchw32 = in_channels % 32 == 0; |
|
|
|
auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32 |
|
|
|
: LayoutType::NHWC_TO_NCHW4; |
|
|
|
inps[0] = RelayoutPlaceholder::make(inps[0], layout_type) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
} |
|
|
|
case Format::NCHW64: |
|
|
|
inps[0] = RelayoutPlaceholder::make( |
|
|
|
inps[0], RelayoutPlaceholder::LayoutType:: |
|
|
|
@@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
auto fmt = iter != format_map.end()?iter->second:Format::NCHW; |
|
|
|
if (iter != format_map.end()) { |
|
|
|
switch (fmt) { |
|
|
|
case Format::NHWC: |
|
|
|
inps[i] = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType:: |
|
|
|
NHWC_TO_NCHW) |
|
|
|
.node(); |
|
|
|
break; |
|
|
|
case Format::NCHW4: |
|
|
|
inps[i] = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
|