From f9c4e90962c91a9eb43cede55df16dbbe89a6332 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 27 Apr 2022 15:30:38 +0800 Subject: [PATCH] catch more foldable constant in static shape inference mode (#3738) --- tools/pnnx/src/pass_level0/shape_inference.cpp | 14 ++++++++++---- tools/pnnx/src/pass_level5/eliminate_dropout.cpp | 2 ++ .../src/pass_level5/eliminate_noop_expression.cpp | 2 ++ tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp | 2 ++ tools/pnnx/src/pass_level5/eliminate_slice.cpp | 2 ++ .../src/pass_level5/eliminate_view_reshape.cpp | 2 ++ 6 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index 51ed9a9ff..ddf784b83 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -21,8 +21,14 @@ namespace pnnx { -static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs) +static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs, bool ignore_aten_size) { + if (ignore_aten_size && v->node()->kind().toDisplayString() == std::string("aten::size")) + { + // any intermediate shape is constant with static input shape + return false; + } + for (auto x : inputs) { if (v == x) @@ -31,7 +37,7 @@ static bool value_link_input(const torch::jit::Value* v, const std::vectornode()->inputs().size(); i++) { - bool link = value_link_input(v->node()->inputs()[i], inputs); + bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); if (link) return true; } @@ -183,7 +189,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(c10::TensorType::create(t)); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) { output_tensors[v] = t; } @@ -221,7 +227,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(finaltype); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) { output_tensors[v] = t; } diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp index b570f7de1..9809ac189 100644 --- a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp @@ -53,6 +53,8 @@ void eliminate_dropout(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = dropout_out->name; + dropout_out->producer = 0; dropout_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp index 97a5d3e42..02f9a9342 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp @@ -60,6 +60,8 @@ void eliminate_noop_expression(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = expr_out->name; + expr_out->producer = 0; expr_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp index f3129a933..4cb64f702 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp @@ -68,6 +68,8 @@ void eliminate_noop_pad(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = pad_out->name; + pad_out->producer = 0; pad_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.cpp b/tools/pnnx/src/pass_level5/eliminate_slice.cpp index afe8ae073..7be03ddf6 100644 --- a/tools/pnnx/src/pass_level5/eliminate_slice.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_slice.cpp @@ -62,6 +62,8 @@ void eliminate_slice(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = slice_out->name; + slice_out->producer = 0; slice_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp index 342b4048c..c3097bdb4 100644 --- a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp @@ -60,6 +60,8 @@ void eliminate_view_reshape(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = op_out->name; + op_out->producer = 0; op_out->consumers.clear();