|
|
@@ -420,7 +420,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { |
|
|
dst[4] = 32; |
|
|
dst[4] = 32; |
|
|
} else if (layout_type() == |
|
|
} else if (layout_type() == |
|
|
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) { |
|
|
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) { |
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", inp_shape.to_string().c_str()); |
|
|
|
|
|
|
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", |
|
|
|
|
|
inp_shape.to_string().c_str()); |
|
|
dst.ndim = 5; |
|
|
dst.ndim = 5; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[1] = inp_shape[1] / 64; |
|
|
dst[1] = inp_shape[1] / 64; |
|
|
@@ -438,8 +439,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { |
|
|
dst[4] = 32; |
|
|
dst[4] = 32; |
|
|
} else if (layout_type() == |
|
|
} else if (layout_type() == |
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) { |
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) { |
|
|
mgb_assert(layout_type() == |
|
|
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); |
|
|
|
|
|
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0); |
|
|
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0); |
|
|
dst.ndim = 5; |
|
|
dst.ndim = 5; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[0] = inp_shape[0]; |
|
|
@@ -499,18 +498,17 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { |
|
|
} else if (layout_type() == |
|
|
} else if (layout_type() == |
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) { |
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) { |
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0); |
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0); |
|
|
dst.ndim = 4; |
|
|
|
|
|
|
|
|
dst.ndim = 5; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[1] = inp_shape[3] / 32; |
|
|
dst[1] = inp_shape[3] / 32; |
|
|
dst[2] = inp_shape[1]; |
|
|
dst[2] = inp_shape[1]; |
|
|
dst[3] = inp_shape[2]; |
|
|
dst[3] = inp_shape[2]; |
|
|
dst[4] = 32; |
|
|
dst[4] = 32; |
|
|
} else if (layout_type() == |
|
|
|
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64) { |
|
|
|
|
|
|
|
|
} else { |
|
|
mgb_assert(layout_type() == |
|
|
mgb_assert(layout_type() == |
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); |
|
|
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); |
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0); |
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0); |
|
|
dst.ndim = 4; |
|
|
|
|
|
|
|
|
dst.ndim = 5; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[0] = inp_shape[0]; |
|
|
dst[1] = inp_shape[3] / 64; |
|
|
dst[1] = inp_shape[3] / 64; |
|
|
dst[2] = inp_shape[1]; |
|
|
dst[2] = inp_shape[1]; |
|
|
@@ -3729,21 +3727,6 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
return y1.node(); |
|
|
return y1.node(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* { |
|
|
|
|
|
mgb_assert(inp->shape().ndim == 4); |
|
|
|
|
|
auto x = SymbolVar(inp); |
|
|
|
|
|
auto xshp = opr::GetVarShape::make(x); |
|
|
|
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
|
|
|
auto sub = [&xshp, &cv](int idx) { |
|
|
|
|
|
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); |
|
|
|
|
|
}; |
|
|
|
|
|
auto tshp = opr::Concat::make( |
|
|
|
|
|
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); |
|
|
|
|
|
auto y0 = opr::Reshape::make(x, tshp); |
|
|
|
|
|
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); |
|
|
|
|
|
return y1.node(); |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers, |
|
|
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers, |
|
|
&nchw42nchw]( |
|
|
&nchw42nchw]( |
|
|
OperatorNodeBase* opr) { |
|
|
OperatorNodeBase* opr) { |
|
|
@@ -3915,31 +3898,29 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
return true; |
|
|
return true; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64, |
|
|
|
|
|
&readers](OperatorNodeBase* opr) { |
|
|
|
|
|
|
|
|
auto try_conv_reformat_nchw42nhwc = [&rewriter, &nchw42nhwc, |
|
|
|
|
|
&readers](OperatorNodeBase* opr) { |
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
ThinHashSet<OperatorNodeBase*> reader_set; |
|
|
ThinHashSet<OperatorNodeBase*> reader_set; |
|
|
// check reshape |
|
|
// check reshape |
|
|
auto reshape1 = |
|
|
|
|
|
try_cast_as_op<opr::Reshape>(opr); |
|
|
|
|
|
if (reshape1 == nullptr) |
|
|
|
|
|
|
|
|
auto reshape = try_cast_as_op<opr::Reshape>(opr); |
|
|
|
|
|
if (reshape == nullptr) |
|
|
return false; |
|
|
return false; |
|
|
opr_set.insert(opr); |
|
|
opr_set.insert(opr); |
|
|
|
|
|
|
|
|
// check dimshuffle |
|
|
// check dimshuffle |
|
|
auto shuffle = try_cast_as_op<opr::Dimshuffle>( |
|
|
auto shuffle = try_cast_as_op<opr::Dimshuffle>( |
|
|
reshape1->input(0)->owner_opr()); |
|
|
|
|
|
|
|
|
reshape->input(0)->owner_opr()); |
|
|
if (shuffle == nullptr) |
|
|
if (shuffle == nullptr) |
|
|
return false; |
|
|
return false; |
|
|
auto&& param = shuffle->param(); |
|
|
auto&& param = shuffle->param(); |
|
|
if (param.pattern_len != 6) |
|
|
|
|
|
|
|
|
if (param.pattern_len != 5) |
|
|
return false; |
|
|
return false; |
|
|
bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 && |
|
|
|
|
|
param.pattern[2] == 3 && param.pattern[3] == 4 && |
|
|
|
|
|
param.pattern[4] == 2 && param.pattern[5] == 5 && |
|
|
|
|
|
shuffle->output(0)->shape()[5] == 4 && |
|
|
|
|
|
shuffle->output(0)->shape()[4] == 16; |
|
|
|
|
|
if (!is_nchw42nchw64) |
|
|
|
|
|
|
|
|
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && |
|
|
|
|
|
param.pattern[2] == 3 && param.pattern[3] == 1 && |
|
|
|
|
|
param.pattern[4] == 4 && |
|
|
|
|
|
shuffle->output(0)->shape()[4] == 4; |
|
|
|
|
|
if (!is_nchw42nhwc) |
|
|
return false; |
|
|
return false; |
|
|
opr_set.insert(shuffle); |
|
|
opr_set.insert(shuffle); |
|
|
for (auto&& i : readers[shuffle]) { |
|
|
for (auto&& i : readers[shuffle]) { |
|
|
@@ -3948,20 +3929,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// check reshape |
|
|
|
|
|
auto reshape2 = |
|
|
|
|
|
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); |
|
|
|
|
|
if (reshape2 == nullptr) |
|
|
|
|
|
return false; |
|
|
|
|
|
opr_set.insert(reshape2); |
|
|
|
|
|
for (auto&& i : readers[reshape2]) { |
|
|
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
|
|
reader_set.insert(i.first); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto typecvt = |
|
|
auto typecvt = |
|
|
try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr()); |
|
|
|
|
|
|
|
|
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr()); |
|
|
if (typecvt == nullptr) |
|
|
if (typecvt == nullptr) |
|
|
return false; |
|
|
return false; |
|
|
auto in_dtype = typecvt->input(0)->dtype(), |
|
|
auto in_dtype = typecvt->input(0)->dtype(), |
|
|
@@ -3972,6 +3941,11 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
if (!is_s82s4) |
|
|
if (!is_s82s4) |
|
|
return false; |
|
|
return false; |
|
|
opr_set.insert(typecvt); |
|
|
opr_set.insert(typecvt); |
|
|
|
|
|
for (auto&& i : readers[typecvt]) { |
|
|
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
|
|
reader_set.insert(i.first); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// check conv bias |
|
|
// check conv bias |
|
|
auto conv_bias = |
|
|
auto conv_bias = |
|
|
@@ -4006,11 +3980,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
auto conv_bias_shuffle = opr::ConvBias::make( |
|
|
auto conv_bias_shuffle = opr::ConvBias::make( |
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
OperatorNodeConfig{out_dtype}); |
|
|
OperatorNodeConfig{out_dtype}); |
|
|
auto new_var = nhwc2nchw64(conv_bias_shuffle.node()); |
|
|
|
|
|
rewriter.replace_var( |
|
|
rewriter.replace_var( |
|
|
opr->output(0), new_var, |
|
|
|
|
|
|
|
|
opr->output(0), conv_bias_shuffle.node(), |
|
|
mgb_cstr_log("replace conv_bias + " |
|
|
mgb_cstr_log("replace conv_bias + " |
|
|
"reformat to conv_bias(NCHW4_NCHW64)")); |
|
|
|
|
|
|
|
|
"reformat to conv_bias(NCHW4_NHWC)")); |
|
|
return true; |
|
|
return true; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
@@ -4098,14 +4071,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
|
|
|
|
|
|
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, |
|
|
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, |
|
|
&try_conv_reformat_nchw42nchw32, |
|
|
&try_conv_reformat_nchw42nchw32, |
|
|
&try_conv_reformat_nchw42nchw64, |
|
|
|
|
|
|
|
|
&try_conv_reformat_nchw42nhwc, |
|
|
#if CUDA_VERSION >= 10020 |
|
|
#if CUDA_VERSION >= 10020 |
|
|
&try_conv_reformat_nchw322nchw4, |
|
|
&try_conv_reformat_nchw322nchw4, |
|
|
#endif |
|
|
#endif |
|
|
&rewriter](OperatorNodeBase* opr) { |
|
|
&rewriter](OperatorNodeBase* opr) { |
|
|
if (!try_conv_dimshuffle_reshape_typecvt(opr) && |
|
|
if (!try_conv_dimshuffle_reshape_typecvt(opr) && |
|
|
!try_conv_reformat_nchw42nchw32(opr) && |
|
|
|
|
|
!try_conv_reformat_nchw42nchw64(opr) |
|
|
|
|
|
|
|
|
!try_conv_reformat_nchw42nchw32(opr) && |
|
|
|
|
|
!try_conv_reformat_nchw42nhwc(opr) |
|
|
#if CUDA_VERSION >= 10020 |
|
|
#if CUDA_VERSION >= 10020 |
|
|
&& !try_conv_reformat_nchw322nchw4(opr) |
|
|
&& !try_conv_reformat_nchw322nchw4(opr) |
|
|
#endif |
|
|
#endif |
|
|
@@ -4546,6 +4519,9 @@ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, |
|
|
case Format::NCHW64: |
|
|
case Format::NCHW64: |
|
|
type = LayoutType::NCHW64_TO_NCHW; |
|
|
type = LayoutType::NCHW64_TO_NCHW; |
|
|
break; |
|
|
break; |
|
|
|
|
|
case Format::NHWC: |
|
|
|
|
|
type = LayoutType::NHWC_TO_NCHW; |
|
|
|
|
|
break; |
|
|
default: |
|
|
default: |
|
|
mgb_throw(AssertionError, |
|
|
mgb_throw(AssertionError, |
|
|
"format(%d) is not supported, related var " |
|
|
"format(%d) is not supported, related var " |
|
|
@@ -4980,7 +4956,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
case Format::NHWC: |
|
|
case Format::NHWC: |
|
|
inps[1] = RelayoutPlaceholder::make( |
|
|
inps[1] = RelayoutPlaceholder::make( |
|
|
inps[1], RelayoutPlaceholder::LayoutType:: |
|
|
inps[1], RelayoutPlaceholder::LayoutType:: |
|
|
NCHW_TO_NHWC) |
|
|
|
|
|
|
|
|
NHWC_TO_NCHW4) |
|
|
.node(); |
|
|
.node(); |
|
|
break; |
|
|
break; |
|
|
case Format::NCHW32: |
|
|
case Format::NCHW32: |
|
|
|