|
|
|
@@ -1169,18 +1169,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); |
|
|
|
conv_bias_weights = relayout_weight.node(); |
|
|
|
|
|
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
|
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param); |
|
|
|
conv_bias_bias = relayout_bias.node(); |
|
|
|
mgb_assert(new_inp.size() < 4, |
|
|
|
"ConvertFormat pass does not support fuse Z"); |
|
|
|
bool has_bias = new_inp.size() > 2; |
|
|
|
if (has_bias && |
|
|
|
new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) { |
|
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
|
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param); |
|
|
|
conv_bias_bias = relayout_bias.node(); |
|
|
|
} else if (has_bias) { |
|
|
|
conv_bias_bias = new_inp[2]; |
|
|
|
} |
|
|
|
|
|
|
|
auto new_param = conv_bias_opr.param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NHWCD4; |
|
|
|
mgb_assert(conv_bias_src->shape().ndim == 5 && |
|
|
|
conv_bias_src->format().type() == |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4); |
|
|
|
auto new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|
SymbolVar new_conv_bias_opr; |
|
|
|
if (has_bias) { |
|
|
|
new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|
} else { |
|
|
|
new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_weights, new_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|
} |
|
|
|
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr(); |
|
|
|
mgb_assert(new_conv_bias_opr.shape().ndim == 5 && |
|
|
|
new_conv_bias_opr.format().type() == |
|
|
|
|