Browse Source

match inplace slice copy pattern, rewrite copy uses (#4338)

tags/20221128
nihui GitHub 3 years ago
parent
commit
a2af6369d9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 545 additions and 4 deletions
  1. +2
    -0
      tools/pnnx/src/CMakeLists.txt
  2. +7
    -0
      tools/pnnx/src/ir.cpp
  3. +0
    -4
      tools/pnnx/src/pass_level1.cpp
  4. +104
    -0
      tools/pnnx/src/pass_level2.cpp
  5. +64
    -0
      tools/pnnx/src/pass_level2/Tensor_copy.cpp
  6. +3
    -0
      tools/pnnx/src/pass_level5.cpp
  7. +3
    -0
      tools/pnnx/src/pass_level5/fold_constants.cpp
  8. +279
    -0
      tools/pnnx/src/pass_level5/fuse_slice_copy.cpp
  9. +21
    -0
      tools/pnnx/src/pass_level5/fuse_slice_copy.h
  10. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  11. +61
    -0
      tools/pnnx/tests/test_Tensor_slice_copy.py

+ 2
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -174,6 +174,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_upsample_nearest.cpp
pass_level2/F_upsample.cpp
pass_level2/Tensor_contiguous.cpp
pass_level2/Tensor_copy.cpp
pass_level2/Tensor_expand.cpp
pass_level2/Tensor_expand_as.cpp
pass_level2/Tensor_index.cpp
@@ -314,6 +315,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_contiguous_view.cpp
pass_level5/fuse_linear_batchnorm1d.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp
pass_level5/fuse_slice_to_tensor_split.cpp
pass_level5/fuse_static_conv.cpp


+ 7
- 0
tools/pnnx/src/ir.cpp View File

@@ -1658,6 +1658,13 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
std::string slice_expr = make_slice_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str());
}
else if (op->type == "Tensor.slice_copy")
{
// slice copy expr
std::string slice_expr = make_slice_expression(op);
fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str());
fprintf(pyfp, " v_%s[%s] = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), sanitize_identifier(op->inputs[1]->name).c_str());
}
else if (op->type == "Tensor.index")
{
// index expr


+ 0
- 4
tools/pnnx/src/pass_level1.cpp View File

@@ -376,10 +376,6 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit

Operator* op = pg.new_operator(n->kind().toDisplayString(), name);

// always treat inplace op type as non-inplace version
if (op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_')
op->type = op->type.substr(0, op->type.size() - 1);

for (int i = 0; i < (int)n->inputs().size(); i++)
{
const auto& in = n->input(i);


+ 104
- 0
tools/pnnx/src/pass_level2.cpp View File

@@ -502,8 +502,112 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
}
}

static void fix_inplace_copy_output(Graph& graph)
{
while (1)
{
bool matched = false;
for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_';
if (!is_inplace_op)
continue;

// replace inplace op with non-inplace version
op->type = op->type.substr(0, op->type.size() - 1);

if (op->type == "aten::copy")
continue;

if (op->outputs[0]->consumers.size() != 0)
continue;

matched = true;

// find in0 from slice / select chain
Operand* in0 = op->inputs[0];
while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select")
{
in0 = in0->producer->inputs[0];
}

// append copy for inplace op
Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op);
Operand* copy_out = graph.new_operand(op->name + "_copy_out");

copy_out->shape = in0->shape;

op_copy->inputs.push_back(op->inputs[0]);
op_copy->inputs.push_back(op->outputs[0]);
op->inputs[0]->consumers.push_back(op_copy);
op->outputs[0]->consumers.push_back(op_copy);

op_copy->outputs.push_back(copy_out);
copy_out->producer = op_copy;

break;
}

if (!matched)
break;
}

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "aten::copy")
continue;

if (op->outputs[0]->consumers.size() != 0)
continue;

// aten::slice 5 1 in0 .... a
// aten::slice 5 1 a .... b
// aten::copy 2 1 b in1 out

// aten::select 3 1 in0 .... a
// aten::copy 2 1 a in1 out

// find in0 from slice / select chain
Operand* in0 = op->inputs[0];
while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select")
{
in0 = in0->producer->inputs[0];
}

// replace all the following uses of in0 with out
Operand* out0 = op->outputs[0];
out0->shape = in0->shape;
for (size_t j = i; j < graph.ops.size(); j++)
{
Operator* op2 = graph.ops[j];

bool use_in0 = false;
for (size_t k = 0; k < op2->inputs.size(); k++)
{
if (op2->inputs[k] == in0)
{
op2->inputs[k] = out0;
use_in0 = true;
}
}

if (use_in0)
{
in0->remove_consumer(op2);
out0->consumers.push_back(op2);
}
}
}
}

void pass_level2(Graph& g)
{
fix_inplace_copy_output(g);

int opindex = 0;
for (auto x : g_global_pnnx_graph_rewriter_passes)
{


+ 64
- 0
tools/pnnx/src/pass_level2/Tensor_copy.cpp View File

@@ -0,0 +1,64 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_copy : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 self
pnnx.Input input_1 0 1 src
prim::Constant op_0 0 1 non_blocking value=*
aten::copy op_1 3 1 self src non_blocking out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.copy";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy, 20)

class Tensor_copy_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 self
pnnx.Input input_1 0 1 src
aten::copy op_1 2 1 self src out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.copy";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_copy_1, 20)

} // namespace pnnx

+ 3
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -34,6 +34,7 @@
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_slice_copy.h"
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level5/fuse_slice_to_tensor_split.h"
#include "pass_level5/fuse_static_conv.h"
@@ -66,6 +67,8 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_slice_to_tensor_split(g);

fuse_slice_copy(g);

fuse_static_conv(g);

fuse_conv1d_batchnorm1d(g);


+ 3
- 0
tools/pnnx/src/pass_level5/fold_constants.cpp View File

@@ -22,6 +22,9 @@ namespace pnnx {

void fold_constants(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
{
if (foldable_constants.empty())
return;

StoreZipReader zip;
zip.open(foldable_constants_zippath);



+ 279
- 0
tools/pnnx/src/pass_level5/fuse_slice_copy.cpp View File

@@ -0,0 +1,279 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_slice_copy.h"

#include <limits.h>
#include <algorithm>
#include <stack>
#include "pass_level2.h"

namespace pnnx {

void fuse_slice_copy(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "Tensor.copy")
continue;

// collect slice / select op chain
std::stack<const Operator*> slice_select_ops;
int descent_dim_current = INT_MAX;
const Operand* in0 = op->inputs[0];
while (in0->producer->type == "Tensor.slice" || in0->producer->type == "Tensor.select")
{
const Operator* sop = in0->producer;
if (sop->type == "Tensor.slice")
{
if (sop->params.find("dims") == sop->params.end()
|| sop->params.find("starts") == sop->params.end()
|| sop->params.find("ends") == sop->params.end()
|| sop->params.find("steps") == sop->params.end())
{
fprintf(stderr, "dynamic index in slice copy chain is not supported\n");
break;
}

int dims0 = sop->params.at("dims").ai[0];
if (descent_dim_current < dims0)
{
break;
}

descent_dim_current = dims0;
}

if (sop->type == "Tensor.select")
{
if (sop->params.find("dim") == sop->params.end()
|| sop->params.find("index") == sop->params.end())
{
fprintf(stderr, "dynamic index in select copy chain is not supported\n");
break;
}

int dim = sop->params.at("dim").i;
if (descent_dim_current < dim)
{
break;
}

descent_dim_current = dim;
}

slice_select_ops.push(sop);
in0 = sop->inputs[0];
}

matched = true;

if (slice_select_ops.empty())
{
// eliminate noop copy
Operand* out = op->outputs[0];

for (auto& x : out->consumers)
{
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == out)
x->inputs[j] = op->inputs[1];
}

op->inputs[1]->consumers.push_back(x);
}

op->inputs[0]->remove_consumer(op);
op->inputs[1]->remove_consumer(op);

op->inputs[1]->name = out->name;

out->producer = 0;
out->consumers.clear();

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), out));
delete out;

op->inputs.clear();
op->outputs.clear();

graph.ops.erase(graph.ops.begin() + i);
delete op;

break;
}

const Operator* top_sop = slice_select_ops.top();

// construct one-step slice
std::vector<int> new_dims;
std::vector<int> new_starts;
std::vector<int> new_ends;
std::vector<int> new_steps;

int select_dims_offset = 0;
while (!slice_select_ops.empty())
{
const Operator* sop = slice_select_ops.top();
slice_select_ops.pop();

if (sop->type == "Tensor.slice")
{
std::vector<int> dims = sop->params.at("dims").ai;
std::vector<int> starts = sop->params.at("starts").ai;
std::vector<int> ends = sop->params.at("ends").ai;
std::vector<int> steps = sop->params.at("steps").ai;

for (size_t j = 0; j < dims.size(); j++)
{
dims[j] += select_dims_offset;
}

new_dims.insert(new_dims.end(), dims.begin(), dims.end());
new_starts.insert(new_starts.end(), starts.begin(), starts.end());
new_ends.insert(new_ends.end(), ends.begin(), ends.end());
new_steps.insert(new_steps.end(), steps.begin(), steps.end());
}
else if (sop->type == "Tensor.select")
{
int dim = sop->params.at("dim").i;
int index = sop->params.at("index").i;

dim += select_dims_offset;
int end = index + 1;
if (index == -1)
end = INT_MAX;

new_dims.push_back(dim);
new_starts.push_back(index);
new_ends.push_back(end);
new_steps.push_back(1);

select_dims_offset += 1;
}
}

op->type = "Tensor.slice_copy";

// insert clone before any slices
Operator* op_clone = graph.new_operator_before("Tensor.clone", op->name + "_ncnnclone", top_sop);
Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out");

clone_out->shape = top_sop->inputs[0]->shape;

op_clone->inputs.push_back(top_sop->inputs[0]);
top_sop->inputs[0]->consumers.push_back(op_clone);

op_clone->outputs.push_back(clone_out);
clone_out->producer = op_clone;

op->inputs[0]->remove_consumer(op);
op->inputs[0] = clone_out;
clone_out->consumers.push_back(op);

op->params["dims"] = new_dims;
op->params["starts"] = new_starts;
op->params["ends"] = new_ends;
op->params["steps"] = new_steps;

int input_rank = (int)op->inputs[0]->shape.size();
if (input_rank == 0)
{
// insert view_as(sliced) for different or unknown rank
Operator* op_slice = graph.new_operator_before("Tensor.slice", op->name + "_ncnnslice", op);
Operator* op_view_as = graph.new_operator_before("Tensor.view_as", op->name + "_ncnnview_as", op);

Operand* slice_out = graph.new_operand(op->name + "_ncnnslice_out");
Operand* view_as_out = graph.new_operand(op->name + "_ncnnview_as_out");

op_slice->params["dims"] = new_dims;
op_slice->params["starts"] = new_starts;
op_slice->params["ends"] = new_ends;
op_slice->params["steps"] = new_steps;

op_slice->inputs.push_back(op->inputs[0]);
op->inputs[0]->consumers.push_back(op_slice);

op_slice->outputs.push_back(slice_out);
slice_out->producer = op_slice;

op_view_as->inputs.push_back(op->inputs[1]);
op->inputs[1]->consumers.push_back(op_view_as);
op->inputs[1]->remove_consumer(op);
op_view_as->inputs.push_back(slice_out);
slice_out->consumers.push_back(op_view_as);

op_view_as->outputs.push_back(view_as_out);
view_as_out->producer = op_view_as;

op->inputs[1] = view_as_out;
view_as_out->consumers.push_back(op);
}
else if (input_rank != (int)op->inputs[1]->shape.size())
{
// solve the target shape
std::vector<int> target_shape = op->inputs[0]->shape;
for (size_t j = 0; j < new_dims.size(); j++)
{
int dim = new_dims[j];
int start = new_starts[j];
int end = new_ends[j];
int step = new_steps[j];

if (dim < 0)
dim = input_rank + dim;
if (start < 0)
start = target_shape[dim] + start;
if (end < 0)
end = target_shape[dim] + end;
if (end == INT_MAX)
end = target_shape[dim];

target_shape[dim] = (end - start + (step - 1)) / step;
}

Operator* op_view = graph.new_operator_before("Tensor.view", op->name + "_ncnnview", op);
Operand* view_out = graph.new_operand(op->name + "_ncnnview_out");

op_view->params["shape"] = target_shape;

view_out->shape = target_shape;

op_view->inputs.push_back(op->inputs[1]);
op->inputs[1]->consumers.push_back(op_view);
op->inputs[1]->remove_consumer(op);

op_view->outputs.push_back(view_out);
view_out->producer = op_view;

op->inputs[1] = view_out;
view_out->consumers.push_back(op);
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_slice_copy.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_slice_copy(Graph& graph);

} // namespace pnnx

+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -171,6 +171,7 @@ pnnx_add_test(Tensor_repeat)
pnnx_add_test(Tensor_reshape)
pnnx_add_test(Tensor_select)
pnnx_add_test(Tensor_slice)
pnnx_add_test(Tensor_slice_copy)
pnnx_add_test(Tensor_view)

pnnx_add_test(torch_addmm)


+ 61
- 0
tools/pnnx/tests/test_Tensor_slice_copy.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
x = x.clone()
x[2:10,...] += 1
x[...,1] = x[...,-1] * 3
y = x.clone()
x[:,:,3,::2].clamp_(0, 0.5)
x[:,:,3,::2] = x[:,:,3,::2].exp_()
x[:,:,::2,:] = y[:,:,::2,:].pow(2)
x[:,:,:,:] = x[:,:,:,:] / 2
return x

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(18, 15, 19, 20)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_Tensor_slice_copy.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_slice_copy.pt inputshape=[18,15,19,20]")

# pnnx inference
import test_Tensor_slice_copy_pnnx
b = test_Tensor_slice_copy_pnnx.test_inference()

return torch.equal(a, b)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save