diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index eaff15f1f..5695e4a95 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -309,6 +309,7 @@ set(pnnx_pass_level5_SRCS pass_level5/eliminate_noop_upsample.cpp pass_level5/eliminate_noop_slice.cpp pass_level5/eliminate_noop_view_reshape.cpp + pass_level5/eliminate_reshape_shape_expression.cpp pass_level5/eval_expression.cpp pass_level5/fold_constants.cpp pass_level5/fuse_adjacent_reshape.cpp diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 72030dd47..001765348 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -25,6 +25,7 @@ #include "pass_level5/eliminate_noop_upsample.h" #include "pass_level5/eliminate_noop_slice.h" #include "pass_level5/eliminate_noop_view_reshape.h" +#include "pass_level5/eliminate_reshape_shape_expression.h" #include "pass_level5/eval_expression.h" #include "pass_level5/fuse_adjacent_reshape.h" #include "pass_level5/fuse_channel_shuffle.h" @@ -119,6 +120,8 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons eliminate_noop_view_reshape(g); + eliminate_reshape_shape_expression(g); + fuse_channel_shuffle(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp index e6b00e87b..1f79a7a63 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp @@ -40,6 +40,19 @@ void eliminate_noop_view_reshape(Graph& graph) if (input_shape.empty()) continue; + // if only one dynamic dim-size + int dynamic_dim_count = 0; + for (size_t j = 0; j < output_shape.size(); j++) + { + if (output_shape[j] == -1) + { + dynamic_dim_count += 1; + } + } + + if (dynamic_dim_count > 1) + continue; + matched = true; for (auto& x : op->inputs) diff --git 
a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
new file mode 100644
index 000000000..744d29016
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
@@ -0,0 +1,190 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "eliminate_reshape_shape_expression.h"
+
+#include <algorithm>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace pnnx {
+
+// Returns true when the whole token parses as a decimal int, with no leading
+// whitespace and no trailing characters.
+static bool token_is_interger_literal(const std::string& t)
+{
+    std::istringstream iss(t);
+    int f;
+    // noskipws: leading whitespace makes the extraction fail instead of being skipped
+    iss >> std::noskipws >> f;
+    // eof() rejects trailing characters, fail() rejects non-numbers
+    return iss.eof() && !iss.fail();
+}
+
+// Maps one top-level list element to a dim-size: its value when it is an
+// integer literal, -1 (dynamic) otherwise.
+static int token_to_dimsize(const std::string& t)
+{
+    return token_is_interger_literal(t) ? std::stoi(t) : -1;
+}
+
+// Parses a bracketed list expression into a shape, e.g. "[2,f(x),-1]" gives
+// {2, -1, -1}. Integer literals keep their value; any element that contains
+// a nested (...) or [...] becomes -1.
+static std::vector<int> build_shape(const std::string& expr)
+{
+    // need at least the opening and closing bracket, substr(1, ...) throws on ""
+    if (expr.size() < 2)
+        return std::vector<int>();
+
+    // strip the outer brackets
+    const std::string listexpr = expr.substr(1, expr.size() - 2);
+
+    std::vector<int> shape;
+
+    std::string t;
+    int level = 0;
+    for (size_t i = 0; i < listexpr.size(); i++)
+    {
+        const char ch = listexpr[i];
+
+        if (ch == '(' || ch == '[')
+        {
+            // entering a nested group, the element is no longer a plain literal
+            level += 1;
+            t = "-1";
+        }
+        else if (ch == ')' || ch == ']')
+        {
+            // discard whatever was collected inside the group
+            level -= 1;
+            t = "-1";
+        }
+        else if (level == 0 && ch == ',')
+        {
+            shape.push_back(token_to_dimsize(t));
+            t.clear();
+        }
+        else
+        {
+            t += ch;
+        }
+    }
+
+    // the last element has no trailing comma
+    if (level == 0 && !t.empty())
+    {
+        shape.push_back(token_to_dimsize(t));
+    }
+
+    return shape;
+}
+
+// For every Tensor.view / Tensor.reshape whose second input is produced by a
+// pnnx.Expression holding a list expression, store the shape as the "shape"
+// param and drop the second input, provided at most one dim-size stays
+// dynamic (-1). The expression operator is deleted once it has no consumer.
+void eliminate_reshape_shape_expression(Graph& graph)
+{
+    while (1)
+    {
+        bool matched = false;
+
+        for (size_t i = 0; i < graph.ops.size(); i++)
+        {
+            Operator* op = graph.ops[i];
+
+            if (op->type != "Tensor.view" && op->type != "Tensor.reshape")
+                continue;
+
+            if (op->inputs.size() != 2)
+                continue;
+
+            Operator* op_expr = op->inputs[1]->producer;
+            if (op_expr->type != "pnnx.Expression")
+                continue;
+
+            std::string expr = op_expr->params.at("expr").s;
+            if (expr.empty() || expr[0] != '[')
+                continue;
+
+            std::vector<int> shape = build_shape(expr);
+
+            // replace -1 with static dim-size
+            // only when both ranks agree, shape[j] would be out of range otherwise
+            std::vector<int> outshape = op->outputs[0]->shape;
+            if (outshape.size() == shape.size())
+            {
+                for (size_t j = 0; j < outshape.size(); j++)
+                {
+                    if (outshape[j] != -1)
+                    {
+                        shape[j] = outshape[j];
+                    }
+                }
+            }
+
+            // if only one dynamic dim-size, drop expression
+            int dynamic_dim_count = 0;
+            for (size_t j = 0; j < shape.size(); j++)
+            {
+                if (shape[j] == -1)
+                {
+                    dynamic_dim_count += 1;
+                }
+            }
+
+            if (dynamic_dim_count > 1)
+                continue;
+
+            matched = true;
+
+            op->params["shape"] = shape;
+
+            op->inputs.resize(1);
+            op_expr->outputs[0]->remove_consumer(op);
+
+            if (op_expr->outputs[0]->consumers.size() == 0)
+            {
+                // remove expression operator
+                // detach it from every input, not only the first one, so that no
+                // operand keeps a consumer pointer to the deleted operator
+                for (size_t j = 0; j < op_expr->inputs.size(); j++)
+                {
+                    op_expr->inputs[j]->remove_consumer(op_expr);
+                }
+
+                Operand* op_expr_out = op_expr->outputs[0];
+
+                graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_expr_out));
+                delete op_expr_out;
+
+                op_expr->inputs.clear();
+                op_expr->outputs.clear();
+
+                graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr));
+                delete op_expr;
+            }
+
+            // graph.ops may have been modified, rescan from the start
+            break;
+        }
+
+        if (!matched)
+            break;
+    }
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h
new file mode 100644
index 000000000..d4457c3ac
--- /dev/null
+++
b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h
@@ -0,0 +1,21 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "ir.h"
+
+namespace pnnx {
+// Replaces the pnnx.Expression shape input of Tensor.view / Tensor.reshape with a "shape" param when at most one dim-size stays dynamic (-1).
+void eliminate_reshape_shape_expression(Graph& graph);
+
+} // namespace pnnx