Browse Source

pnnx eliminate reshape shape expression for only one dynamic dimsize (#4548)

tags/20230517
nihui GitHub 3 years ago
parent
commit
c68266efd0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 199 additions and 0 deletions
  1. +1
    -0
      tools/pnnx/src/CMakeLists.txt
  2. +3
    -0
      tools/pnnx/src/pass_level5.cpp
  3. +13
    -0
      tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp
  4. +161
    -0
      tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
  5. +21
    -0
      tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h

+ 1
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -309,6 +309,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_upsample.cpp
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eliminate_reshape_shape_expression.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp


+ 3
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -25,6 +25,7 @@
#include "pass_level5/eliminate_noop_upsample.h"
#include "pass_level5/eliminate_noop_slice.h"
#include "pass_level5/eliminate_noop_view_reshape.h"
#include "pass_level5/eliminate_reshape_shape_expression.h"
#include "pass_level5/eval_expression.h"
#include "pass_level5/fuse_adjacent_reshape.h"
#include "pass_level5/fuse_channel_shuffle.h"
@@ -119,6 +120,8 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

eliminate_noop_view_reshape(g);

eliminate_reshape_shape_expression(g);

fuse_channel_shuffle(g);

fuse_index_expression(g);


+ 13
- 0
tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp View File

@@ -40,6 +40,19 @@ void eliminate_noop_view_reshape(Graph& graph)
if (input_shape.empty())
continue;

// if only one dynamic dim-size
int dynamic_dim_count = 0;
for (size_t j = 0; j < output_shape.size(); j++)
{
if (output_shape[j] == -1)
{
dynamic_dim_count += 1;
}
}

if (dynamic_dim_count > 1)
continue;

matched = true;

for (auto& x : op->inputs)


+ 161
- 0
tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp View File

@@ -0,0 +1,161 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_reshape_shape_expression.h"

#include <iostream>
#include <sstream>
#include <algorithm>
#include <stack>
#include <vector>
#include <string>

namespace pnnx {

static bool token_is_interger_literal(const std::string& t)
{
std::istringstream iss(t);
int f;
iss >> std::noskipws >> f;
return iss.eof() && !iss.fail();
}

static std::vector<int> build_shape(const std::string& expr)
{
std::string listexpr = expr.substr(1, expr.size() - 2);

std::vector<int> shape;

std::string t;
int level = 0;
for (size_t i = 0; i < listexpr.size(); i++)
{
char ch = listexpr[i];

if (ch == '(' || ch == '[')
{
level += 1;
t = "-1";
}
else if (ch == ')' || ch == ']')
{
level -= 1;
t = "-1";
}
else if (level == 0 && ch == ',')
{
int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1;
shape.push_back(dimsize);
t.clear();
}
else
{
t += ch;
}
}

if (level == 0 && !t.empty())
{
int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1;
shape.push_back(dimsize);
}

return shape;
}

void eliminate_reshape_shape_expression(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "Tensor.view" && op->type != "Tensor.reshape")
continue;

if (op->inputs.size() != 2)
continue;

Operator* op_expr = op->inputs[1]->producer;
if (op_expr->type != "pnnx.Expression")
continue;

std::string expr = op_expr->params.at("expr").s;
if (expr.empty() || expr[0] != '[')
continue;

std::vector<int> shape = build_shape(expr);

// replace -1 with static dim-size
std::vector<int> outshape = op->outputs[0]->shape;
if (!outshape.empty())
{
for (size_t j = 0; j < outshape.size(); j++)
{
if (outshape[j] != -1)
{
shape[j] = outshape[j];
}
}
}

// if only one dynamic dim-size, drop expression
int dynamic_dim_count = 0;
for (size_t j = 0; j < shape.size(); j++)
{
if (shape[j] == -1)
{
dynamic_dim_count += 1;
}
}

if (dynamic_dim_count > 1)
continue;

matched = true;

op->params["shape"] = shape;

op->inputs.resize(1);
op_expr->outputs[0]->remove_consumer(op);

if (op_expr->outputs[0]->consumers.size() == 0)
{
// remove expression operator
op_expr->inputs[0]->remove_consumer(op_expr);

Operand* op_expr_out = op_expr->outputs[0];

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_expr_out));
delete op_expr_out;

op_expr->inputs.clear();
op_expr->outputs.clear();

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr));
delete op_expr;
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void eliminate_reshape_shape_expression(Graph& graph);

} // namespace pnnx

Loading…
Cancel
Save