Browse Source

fix(gopt): fix convbias replace of cd4 pass

GitOrigin-RevId: b0715e2b77
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
6585514902
2 changed files with 22 additions and 7 deletions
  1. +1
    -1
      src/gopt/impl/framework.cpp
  2. +21
    -6
      src/gopt/impl/inference.cpp

+ 1
- 1
src/gopt/impl/framework.cpp View File

@@ -716,8 +716,8 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options(
need_param_fuse = true;
}
if (options.transform_nhwcd4()) {
add_pass(ConvertFormatPass::make_nhwcd4_converter());
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
need_param_fuse = true;
}
if (options.transform_nchw88()) {


+ 21
- 6
src/gopt/impl/inference.cpp View File

@@ -1169,18 +1169,33 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
conv_bias_weights = relayout_weight.node();

param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
conv_bias_bias = relayout_bias.node();
mgb_assert(new_inp.size() < 4,
"ConvertFormat pass does not support fuse Z");
bool has_bias = new_inp.size() > 2;
if (has_bias &&
new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
conv_bias_bias = relayout_bias.node();
} else if (has_bias) {
conv_bias_bias = new_inp[2];
}

auto new_param = conv_bias_opr.param();
new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
SymbolVar new_conv_bias_opr;
if (has_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
}
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5 &&
new_conv_bias_opr.format().type() ==


Loading…
Cancel
Save