diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index 51ed9a9ff..ddf784b83 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -21,8 +21,14 @@ namespace pnnx { -static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs) +static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs, bool ignore_aten_size) { + if (ignore_aten_size && v->node()->kind().toDisplayString() == std::string("aten::size")) + { + // any intermediate shape is constant with static input shape + return false; + } + for (auto x : inputs) { if (v == x) @@ -31,7 +37,7 @@ static bool value_link_input(const torch::jit::Value* v, const std::vectornode()->inputs().size(); i++) { - bool link = value_link_input(v->node()->inputs()[i], inputs); + bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); if (link) return true; } @@ -183,7 +189,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(c10::TensorType::create(t)); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) { output_tensors[v] = t; } @@ -221,7 +227,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(finaltype); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) { output_tensors[v] = t; } diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp index b570f7de1..9809ac189 100644 --- a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp @@ -53,6 +53,8 @@ void eliminate_dropout(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = dropout_out->name; + dropout_out->producer = 0; dropout_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp index 97a5d3e42..02f9a9342 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp @@ -60,6 +60,8 @@ void eliminate_noop_expression(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = expr_out->name; + expr_out->producer = 0; expr_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp index f3129a933..4cb64f702 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp @@ -68,6 +68,8 @@ void eliminate_noop_pad(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = pad_out->name; + pad_out->producer = 0; pad_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.cpp b/tools/pnnx/src/pass_level5/eliminate_slice.cpp index afe8ae073..7be03ddf6 100644 --- a/tools/pnnx/src/pass_level5/eliminate_slice.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_slice.cpp @@ -62,6 +62,8 @@ void eliminate_slice(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = slice_out->name; + slice_out->producer = 0; slice_out->consumers.clear(); diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp index 342b4048c..c3097bdb4 100644 --- a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp @@ -60,6 +60,8 @@ void eliminate_view_reshape(Graph& graph) op->inputs[0]->consumers.push_back(x); } + op->inputs[0]->name = op_out->name; + op_out->producer = 0; op_out->consumers.clear();