|
|
|
@@ -3550,6 +3550,35 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
|
return y2.node(); |
|
|
|
}; |
|
|
|
|
|
|
|
auto nchw42nhwc = [](VarNode* inp) -> VarNode* { |
|
|
|
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); |
|
|
|
auto x = SymbolVar(inp); |
|
|
|
auto xshp = opr::GetVarShape::make(x); |
|
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
|
auto sub = [&xshp, &cv](int idx) { |
|
|
|
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); |
|
|
|
}; |
|
|
|
auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0); |
|
|
|
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); |
|
|
|
auto y1 = opr::Reshape::make(y0, tshp); |
|
|
|
return y1.node(); |
|
|
|
}; |
|
|
|
|
|
|
|
auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* { |
|
|
|
mgb_assert(inp->shape().ndim == 4); |
|
|
|
auto x = SymbolVar(inp); |
|
|
|
auto xshp = opr::GetVarShape::make(x); |
|
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
|
auto sub = [&xshp, &cv](int idx) { |
|
|
|
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); |
|
|
|
}; |
|
|
|
auto tshp = opr::Concat::make( |
|
|
|
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); |
|
|
|
auto y0 = opr::Reshape::make(x, tshp); |
|
|
|
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); |
|
|
|
return y1.node(); |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers, |
|
|
|
&nchw42nchw]( |
|
|
|
OperatorNodeBase* opr) { |
|
|
|
@@ -3721,6 +3750,106 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
|
return true; |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64, |
|
|
|
&readers](OperatorNodeBase* opr) { |
|
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
|
ThinHashSet<OperatorNodeBase*> reader_set; |
|
|
|
// check reshape |
|
|
|
auto reshape1 = |
|
|
|
try_cast_as_op<opr::Reshape>(opr); |
|
|
|
if (reshape1 == nullptr) |
|
|
|
return false; |
|
|
|
opr_set.insert(opr); |
|
|
|
|
|
|
|
// check dimshuffle |
|
|
|
auto shuffle = try_cast_as_op<opr::Dimshuffle>( |
|
|
|
reshape1->input(0)->owner_opr()); |
|
|
|
if (shuffle == nullptr) |
|
|
|
return false; |
|
|
|
auto&& param = shuffle->param(); |
|
|
|
if (param.pattern_len != 6) |
|
|
|
return false; |
|
|
|
bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 && |
|
|
|
param.pattern[2] == 3 && param.pattern[3] == 4 && |
|
|
|
param.pattern[4] == 2 && param.pattern[5] == 5 && |
|
|
|
shuffle->output(0)->shape()[5] == 4 && |
|
|
|
shuffle->output(0)->shape()[4] == 16; |
|
|
|
if (!is_nchw42nchw64) |
|
|
|
return false; |
|
|
|
opr_set.insert(shuffle); |
|
|
|
for (auto&& i : readers[shuffle]) { |
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
reader_set.insert(i.first); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// check reshape |
|
|
|
auto reshape2 = |
|
|
|
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); |
|
|
|
if (reshape2 == nullptr) |
|
|
|
return false; |
|
|
|
opr_set.insert(reshape2); |
|
|
|
for (auto&& i : readers[reshape2]) { |
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
reader_set.insert(i.first); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto typecvt = |
|
|
|
try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr()); |
|
|
|
if (typecvt == nullptr) |
|
|
|
return false; |
|
|
|
auto in_dtype = typecvt->input(0)->dtype(), |
|
|
|
out_dtype = typecvt->output(0)->dtype(); |
|
|
|
printf("%s, %s\n", in_dtype.name(), out_dtype.name()); |
|
|
|
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
(out_dtype.enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
out_dtype.enumv() == DTypeEnum::Quantized4Asymm); |
|
|
|
if (!is_s82s4) |
|
|
|
return false; |
|
|
|
opr_set.insert(typecvt); |
|
|
|
|
|
|
|
// check conv bias |
|
|
|
auto conv_bias = |
|
|
|
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); |
|
|
|
if (conv_bias == nullptr) |
|
|
|
return false; |
|
|
|
auto inp_dtype = conv_bias->input(0)->dtype(); |
|
|
|
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
conv_bias->param().format == |
|
|
|
megdnn::param::ConvBias::Format::NCHW4; |
|
|
|
if (!is_s8nchw4) |
|
|
|
return false; |
|
|
|
if (conv_bias->input().size() != 3) |
|
|
|
return false; |
|
|
|
opr_set.insert(conv_bias); |
|
|
|
for (auto&& i : readers[conv_bias]) { |
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
reader_set.insert(i.first); |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto reader : reader_set) { |
|
|
|
if (opr_set.count(reader) <= 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
auto src = rewriter.get_var(conv_bias->input(0)), |
|
|
|
filter = rewriter.get_var(conv_bias->input(1)), |
|
|
|
bias = rewriter.get_var(conv_bias->input(2)); |
|
|
|
auto new_bias = nchw42nhwc(bias); |
|
|
|
auto new_param = conv_bias->param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; |
|
|
|
auto conv_bias_shuffle = opr::ConvBias::make( |
|
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
|
OperatorNodeConfig{out_dtype}); |
|
|
|
auto new_var = nhwc2nchw64(conv_bias_shuffle.node()); |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(0), new_var, |
|
|
|
mgb_cstr_log("replace conv_bias + " |
|
|
|
"reformat to conv_bias(NCHW4_NCHW64)")); |
|
|
|
return true; |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4]( |
|
|
|
OperatorNodeBase* opr) { |
|
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
|
@@ -3805,12 +3934,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
|
|
|
|
|
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, |
|
|
|
&try_conv_reformat_nchw42nchw32, |
|
|
|
&try_conv_reformat_nchw42nchw64, |
|
|
|
#if CUDA_VERSION >= 10020 |
|
|
|
&try_conv_reformat_nchw322nchw4, |
|
|
|
#endif |
|
|
|
&rewriter](OperatorNodeBase* opr) { |
|
|
|
if (!try_conv_dimshuffle_reshape_typecvt(opr) && |
|
|
|
!try_conv_reformat_nchw42nchw32(opr) |
|
|
|
!try_conv_reformat_nchw42nchw32(opr) && |
|
|
|
!try_conv_reformat_nchw42nchw64(opr) |
|
|
|
#if CUDA_VERSION >= 10020 |
|
|
|
&& !try_conv_reformat_nchw322nchw4(opr) |
|
|
|
#endif |
|
|
|
|