@@ -1771,7 +1771,6 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
         bool can_be_fused = true;
         can_be_fused &= (elem->input().size() == 1);
         can_be_fused &= (elem->param().mode == Mode::RELU) ||
-                        (elem->param().mode == Mode::TANH) ||
                         (elem->param().mode == Mode::SIGMOID);
         return can_be_fused;
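Note on this hunk: TANH is dropped from the set of fusable element-wise modes. As far as I can tell, ConvBias's Param::NonlineMode only enumerates IDENTITY, RELU and SIGMOID (plus H_SWISH in newer versions), so a TANH Elemwise has no fused representation and has to remain a standalone operator.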
@@ -1911,13 +1910,14 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
             }
         } else if (try_fuse_nonlinearity(elem)) {
             auto inp = rewriter.get_var(elem->input(0));
+            auto elem_noline = get_nonlinearity_mode(elem);
             {
                 auto conv = try_cast_as_op<opr::Convolution>(inp->owner_opr());
                 if (conv && check_conv(conv) &&
                     m_deps[elem->input(0)].size() == 1) {
                     opr::ConvBiasForward::Param param =
                             convert_to_conv_bias_param(conv->param());
-                    param.nonlineMode = get_nonlinearity_mode(elem);
+                    param.nonlineMode = elem_noline;
                     auto new_var = opr::ConvBiasForward::make(
                             conv->input(0), conv->input(1), param,
                             conv->execution_policy(), conv->config())
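This hunk hoists the nonlinearity lookup into elem_noline so it is computed once: it is used here, where a plain Convolution is upgraded to a ConvBiasForward, and again in the next hunk, where it must be compared against the mode already fused into an existing ConvBias.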
@@ -1941,9 +1941,16 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
                     ;
             };
             if (conv && check_conv_bias(conv) &&
-                m_deps[elem->input(0)].size() == 1) {
+                m_deps[elem->input(0)].size() == 1 &&
+                conv->input().size() > 2) {
                 auto param = conv->param();
-                param.nonlineMode = get_nonlinearity_mode(elem);
+                bool noline_ok = param.nonlineMode == NonlineMode::IDENTITY ||
+                                 (param.nonlineMode == NonlineMode::RELU &&
+                                  elem_noline == NonlineMode::RELU);
+                if (!noline_ok) {
+                    return;
+                }
+                param.nonlineMode = elem_noline;
                 auto new_var =
                         opr::ConvBiasForward::make(
                                 conv->input(0), conv->input(1), conv->input(2),
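The extended guard does two things: conv->input().size() > 2 makes sure the matched ConvBias actually carries a bias input before conv->input(2) is dereferenced below, and noline_ok only permits re-fusing when composing the two nonlinearities collapses to a single mode. A minimal standalone sketch of that rule (not part of the patch; NonlineMode stands for opr::ConvBiasForward::Param::NonlineMode as in the surrounding code):

    // Fusing a trailing element-wise `next` into a ConvBias whose fused mode
    // is already `cur` is only sound when the composed result is expressible
    // as one mode: anything fuses into IDENTITY (no nonlinearity applied
    // yet), and RELU absorbs a second RELU since relu(relu(x)) == relu(x).
    // All other combinations (e.g. RELU after SIGMOID) are left unfused.
    bool can_refuse_nonlinearity(NonlineMode cur, NonlineMode next) {
        return cur == NonlineMode::IDENTITY ||
               (cur == NonlineMode::RELU && next == NonlineMode::RELU);
    }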
@@ -1731,6 +1731,52 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
 }
 
+TEST(TestGoptInference, ConvBiasNonlinearityFusePass2) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+    auto cn = CompNode::load("cpu0");
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    opr::Convolution::Param param;
+    auto x = mkvar("x", {5, 8, 16, 24}), w1 = mkcvar("w1", {4, 8, 1, 1}),
+         w2 = mkcvar("w2", {4, 8, 1, 1});
+    auto b1 = mkcvar("b1", {1, 4, 1, 1});
+    auto y_cut = opr::Convolution::make(x, w1, param);
+    auto y = opr::Elemwise::make({y_cut + b1}, opr::Elemwise::Param::Mode::SIGMOID);
+    y = opr::Elemwise::make({y}, opr::Elemwise::Param::Mode::RELU);
+    auto y_cut2 = opr::Convolution::make(x, w2, param);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::SIGMOID);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::RELU);
+    y = y + y_cut2;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    ASSERT_EQ(
+            opr::ConvBias::Param::NonlineMode::SIGMOID,
+            find_opr<opr::ConvBias>(y_opt).param().nonlineMode);
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.FuseConvBiasNonlinPass2.json"));
+    HostTensorND host_y, host_y_opt;
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
+}
+
 TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
     NaiveMegDNNHandleScope naive_megdnn_handle;
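As I read the new test, each convolution branch ends in SIGMOID followed by RELU: the SIGMOID can be fused (the ConvBias mode is still IDENTITY at that point), but the trailing RELU must not overwrite it. The nonlineMode == SIGMOID assertion pins that down, and the tensor comparison against the unoptimized graph verifies the unfused RELU is still applied.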