Browse Source

pnnx convert torch.tensor_split, fuse full dim size slice to tensor_split (#3988)

tags/20220721
nihui GitHub 3 years ago
parent
commit
b4bae2c9e4
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 584 additions and 13 deletions
  1. +5
    -2
      tools/pnnx/src/CMakeLists.txt
  2. +65
    -0
      tools/pnnx/src/pass_level2/torch_tensor_split.cpp
  3. +4
    -4
      tools/pnnx/src/pass_level3.cpp
  4. +3
    -3
      tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp
  5. +1
    -1
      tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h
  6. +2
    -2
      tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp
  7. +1
    -1
      tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h
  8. +3
    -0
      tools/pnnx/src/pass_level5.cpp
  9. +151
    -0
      tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp
  10. +21
    -0
      tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h
  11. +2
    -0
      tools/pnnx/src/pass_ncnn.cpp
  12. +102
    -0
      tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp
  13. +25
    -0
      tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.h
  14. +2
    -0
      tools/pnnx/tests/CMakeLists.txt
  15. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  16. +63
    -0
      tools/pnnx/tests/ncnn/test_torch_tensor_split.py
  17. +70
    -0
      tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py
  18. +63
    -0
      tools/pnnx/tests/test_torch_tensor_split.py

+ 5
- 2
tools/pnnx/src/CMakeLists.txt View File

@@ -232,6 +232,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_stack.cpp
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
@@ -265,8 +266,8 @@ set(pnnx_pass_level3_SRCS
pass_level3/eliminate_noop_math.cpp
pass_level3/eliminate_tuple_pair.cpp
pass_level3/expand_quantization_modules.cpp
pass_level3/fuse_cat_stack_tensors.cpp
pass_level3/fuse_chunk_split_unbind_unpack.cpp
pass_level3/fuse_opnto1_tensors.cpp
pass_level3/fuse_op1ton_unpack.cpp
pass_level3/fuse_einsum_operands.cpp
pass_level3/fuse_expression.cpp
pass_level3/fuse_index_expression.cpp
@@ -305,6 +306,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_linear_batchnorm1d.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_indices.cpp
pass_level5/fuse_slice_to_tensor_split.cpp
pass_level5/fuse_static_conv.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
@@ -319,6 +321,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/convert_torch_chunk.cpp
pass_ncnn/convert_torch_einsum.cpp
pass_ncnn/convert_torch_split.cpp
pass_ncnn/convert_torch_tensor_split.cpp
pass_ncnn/convert_torch_unbind.cpp
pass_ncnn/convert_Tensor_select.cpp
pass_ncnn/eliminate_output.cpp


+ 65
- 0
tools/pnnx/src/pass_level2/torch_tensor_split.cpp View File

@@ -0,0 +1,65 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_tensor_split : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dim
prim::Constant op_0 0 1 sections value=%sections
aten::tensor_split op_1 3 1 input sections dim out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.tensor_split";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split, 19)

class torch_tensor_split_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 indices
pnnx.Input input_2 0 1 dim
aten::tensor_split op_0 3 1 input indices dim out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.tensor_split";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split_1, 20)

} // namespace pnnx

+ 4
- 4
tools/pnnx/src/pass_level3.cpp View File

@@ -18,8 +18,8 @@
#include "pass_level3/eliminate_noop_math.h"
#include "pass_level3/eliminate_tuple_pair.h"
#include "pass_level3/expand_quantization_modules.h"
#include "pass_level3/fuse_cat_stack_tensors.h"
#include "pass_level3/fuse_chunk_split_unbind_unpack.h"
#include "pass_level3/fuse_opnto1_tensors.h"
#include "pass_level3/fuse_op1ton_unpack.h"
#include "pass_level3/fuse_einsum_operands.h"
#include "pass_level3/fuse_expression.h"
#include "pass_level3/fuse_index_expression.h"
@@ -39,9 +39,9 @@ void pass_level3(Graph& g, const std::map<std::string, Attribute>& foldable_cons
{
assign_unique_name(g);

fuse_cat_stack_tensors(g);
fuse_opnto1_tensors(g);

fuse_chunk_split_unbind_unpack(g);
fuse_op1ton_unpack(g);

fuse_einsum_operands(g);



tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp → tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp View File

@@ -12,13 +12,13 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_chunk_split_unbind_unpack.h"
#include "fuse_op1ton_unpack.h"
#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_chunk_split_unbind_unpack(Graph& graph)
void fuse_op1ton_unpack(Graph& graph)
{
while (1)
{
@@ -28,7 +28,7 @@ void fuse_chunk_split_unbind_unpack(Graph& graph)
{
Operator* op = graph.ops[i];

if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind")
if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind" && op->type != "torch.tensor_split")
continue;

if (op->outputs.size() != 1)

tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h → tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void fuse_chunk_split_unbind_unpack(Graph& graph);
void fuse_op1ton_unpack(Graph& graph);

} // namespace pnnx

tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.cpp → tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp View File

@@ -12,13 +12,13 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_cat_stack_tensors.h"
#include "fuse_opnto1_tensors.h"
#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_cat_stack_tensors(Graph& graph)
void fuse_opnto1_tensors(Graph& graph)
{
while (1)
{

tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.h → tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void fuse_cat_stack_tensors(Graph& graph);
void fuse_opnto1_tensors(Graph& graph);

} // namespace pnnx

+ 3
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -34,6 +34,7 @@
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level5/fuse_slice_to_tensor_split.h"
#include "pass_level5/fuse_static_conv.h"
#include "pass_level5/normalize_einsum_equation.h"
#include "pass_level4/dead_code_elimination.h"
@@ -62,6 +63,8 @@ void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_cons

fuse_select_to_unbind(g);

fuse_slice_to_tensor_split(g);

fuse_static_conv(g);

fuse_conv1d_batchnorm1d(g);


+ 151
- 0
tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp View File

@@ -0,0 +1,151 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_slice_to_tensor_split.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_slice_to_tensor_split(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "Tensor.slice")
continue;

Operand* op_in = op->inputs[0];

if (op->params.find("dims") == op->params.end()
|| op->params.find("starts") == op->params.end()
|| op->params.find("ends") == op->params.end()
|| op->params.find("steps") == op->params.end())
continue;

if (op->params.at("dims").ai.size() != 1)
continue;

int dim = op->params.at("dims").ai[0];
int start = op->params.at("starts").ai[0];
int end = op->params.at("ends").ai[0];
int step = op->params.at("steps").ai[0];
if (start != 0 || step != 1)
continue;

// slice 0 i j k ... n
std::vector<int> tensor_split_indices;
std::vector<Operator*> slice_n_ops;

tensor_split_indices.push_back(end);
slice_n_ops.push_back(op);

bool full_dimsize_slice = false;
while (1)
{
// find slice with starts == end
Operator* op2 = 0;

for (auto x : op_in->consumers)
{
if (x->type != "Tensor.slice")
continue;

if (x->inputs[0] != op_in)
continue;

if (x->params.find("dims") == x->params.end()
|| x->params.find("starts") == x->params.end()
|| x->params.find("ends") == x->params.end()
|| x->params.find("steps") == x->params.end())
continue;

if (x->params.at("dims").ai.size() != 1)
continue;

int dim2 = x->params.at("dims").ai[0];
int start2 = x->params.at("starts").ai[0];
int step2 = x->params.at("steps").ai[0];
if (step2 != 1)
continue;

if (dim == dim2 && start2 == end)
{
op2 = x;
break;
}
}

if (!op2)
break;

int end2 = op2->params.at("ends").ai[0];
if (end2 == -1)
{
slice_n_ops.push_back(op2);
full_dimsize_slice = true;
break;
}

tensor_split_indices.push_back(end2);
slice_n_ops.push_back(op2);

end = end2;
}

if (!full_dimsize_slice)
continue;

matched = true;

// delete all slice ops and replace with tensor_split
Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, op);
op_tensor_split->params["dim"] = dim;
op_tensor_split->params["indices"] = tensor_split_indices;

op_tensor_split->inputs.push_back(op_in);
for (size_t j = 0; j < slice_n_ops.size(); j++)
{
op_in->consumers.erase(std::find(op_in->consumers.begin(), op_in->consumers.end(), slice_n_ops[j]));
}
op_in->consumers.push_back(op_tensor_split);

op_tensor_split->outputs.resize(slice_n_ops.size());
for (size_t j = 0; j < slice_n_ops.size(); j++)
{
op_tensor_split->outputs[j] = slice_n_ops[j]->outputs[0];
slice_n_ops[j]->outputs[0]->producer = op_tensor_split;
}

for (size_t j = 0; j < slice_n_ops.size(); j++)
{
graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), slice_n_ops[j]));
delete slice_n_ops[j];
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_slice_to_tensor_split(Graph& graph);

} // namespace pnnx

+ 2
- 0
tools/pnnx/src/pass_ncnn.cpp View File

@@ -22,6 +22,7 @@
#include "pass_ncnn/convert_torch_chunk.h"
#include "pass_ncnn/convert_torch_einsum.h"
#include "pass_ncnn/convert_torch_split.h"
#include "pass_ncnn/convert_torch_tensor_split.h"
#include "pass_ncnn/convert_torch_unbind.h"
#include "pass_ncnn/convert_Tensor_select.h"
#include "pass_ncnn/eliminate_output.h"
@@ -90,6 +91,7 @@ void pass_ncnn(Graph& g)
ncnn::convert_torch_chunk(g);
ncnn::convert_torch_split(g);
ncnn::convert_torch_unbind(g);
ncnn::convert_torch_tensor_split(g);
ncnn::convert_torch_einsum(g);

ncnn::convert_Tensor_select(g);


+ 102
- 0
tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp View File

@@ -0,0 +1,102 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "convert_torch_tensor_split.h"

namespace pnnx {

namespace ncnn {

void convert_torch_tensor_split(Graph& graph)
{
int op_index = 0;

for (Operator* op : graph.ops)
{
if (op->type != "torch.tensor_split")
continue;

op->type = "Slice";
op->name = std::string("tensor_split_") + std::to_string(op_index++);

const int batch_index = op->inputs[0]->params["__batch_index"].i;

int axis = op->params.at("dim").i;
if (axis == batch_index)
{
fprintf(stderr, "tensor_split along batch axis %d is not supported\n", batch_index);
continue;
}

if (axis < 0)
{
int input_rank = op->inputs[0]->shape.size();
axis = input_rank + axis;
}

if (op->params.find("sections") != op->params.end())
{
int sections = op->params.at("sections").i;

if (!op->inputs[0]->shape.empty())
{
int size = op->inputs[0]->shape[axis];
if (size % sections != 0)
{
fprintf(stderr, "tensor_split with non-perfect divided size %d / %d is not supported\n", size, sections);
}
}

op->params["0"].type = 5;
op->params["0"].ai.resize(sections, -233);

op->params.erase("sections");
}
else
{
const std::vector<int>& indices = op->params.at("indices").ai;

op->params["0"].type = 5;
op->params["0"].ai.resize(indices.size() + 1);

for (size_t i = 0; i < indices.size() + 1; i++)
{
if (i == 0)
{
op->params["0"].ai[i] = indices[i];
}
else if (i == indices.size())
{
op->params["0"].ai[i] = -233;
}
else
{
op->params["0"].ai[i] = indices[i] - indices[i - 1];
}
}

op->params.erase("indices");
}

if (axis > batch_index)
axis -= 1;

op->params["1"] = axis;
op->params.erase("dim");
}
}

} // namespace ncnn

} // namespace pnnx

+ 25
- 0
tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.h View File

@@ -0,0 +1,25 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

void convert_torch_tensor_split(Graph& graph);

} // namespace ncnn

} // namespace pnnx

+ 2
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -207,6 +207,7 @@ pnnx_add_test(torch_sum)
pnnx_add_test(torch_split)
pnnx_add_test(torch_squeeze)
pnnx_add_test(torch_stack)
pnnx_add_test(torch_tensor_split)
pnnx_add_test(torch_transpose)
pnnx_add_test(torch_unbind)
pnnx_add_test(torch_unsqueeze)
@@ -251,6 +252,7 @@ pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
pnnx_add_test(pnnx_fuse_select_to_unbind)
pnnx_add_test(pnnx_fuse_slice_to_tensor_split)

if(Torch_VERSION VERSION_GREATER_EQUAL "1.9")
pnnx_add_test(F_mish)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -144,6 +144,7 @@ pnnx_ncnn_add_test(torch_permute)
pnnx_ncnn_add_test(torch_prod)
pnnx_ncnn_add_test(torch_sum)
pnnx_ncnn_add_test(torch_squeeze)
pnnx_ncnn_add_test(torch_tensor_split)
pnnx_ncnn_add_test(torch_transpose)
pnnx_ncnn_add_test(torch_unbind)
pnnx_ncnn_add_test(torch_unsqueeze)


+ 63
- 0
tools/pnnx/tests/ncnn/test_torch_tensor_split.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x0, x1, x2 = torch.tensor_split(x, (12, 13))
y0, y1, y2 = torch.tensor_split(y, 3, dim=1)
z0, z1 = torch.tensor_split(z, (3,), dim=0)
return x0, x1, x2, y0, y1, y2, z0, z1

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(100)
y = torch.rand(3, 15)
z = torch.rand(5, 9, 3)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_tensor_split.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,15],[5,9,3]")

# ncnn inference
import test_torch_tensor_split_ncnn
b = test_torch_tensor_split_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
print(a0.shape)
print(b0.shape)
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 70
- 0
tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py View File

@@ -0,0 +1,70 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x0 = x[:3]
x1 = x[3:]

y0 = y[:2,:]
y1 = y[2:4,:]
y2 = y[4:,:]

z0 = z[:,:,:2]
z1 = z[:,:,2:4]
z2 = z[:,:,4:7]
z3 = z[:,:,7:]

return x0, x1, y0, y1, y2, z0, z1, z2, z3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(8)
y = torch.rand(9, 10)
z = torch.rand(8, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_pnnx_fuse_slice_to_tensor_split.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_fuse_slice_to_tensor_split.pt inputshape=[8],[9,10],[8,9,10]")

# pnnx inference
import test_pnnx_fuse_slice_to_tensor_split_pnnx
b = test_pnnx_fuse_slice_to_tensor_split_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 63
- 0
tools/pnnx/tests/test_torch_tensor_split.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x0, x1, x2 = torch.tensor_split(x, (12, 13))
y0, y1, y2 = torch.tensor_split(y, 3, dim=1)
z0, z1 = torch.tensor_split(z, (3,), dim=0)
w0, w1, w2, w3, w4 = torch.tensor_split(w, (1, 3, 7, 17), dim=3)
return x0, x1, x2, y0, y1, y2, z0, z1, w0, w1, w2, w3, w4

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(100)
y = torch.rand(3, 16)
z = torch.rand(5, 9, 3)
w = torch.rand(6, 13, 6, 22)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_tensor_split.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,16],[5,9,3],[6,13,6,22]")

# pnnx inference
import test_torch_tensor_split_pnnx
b = test_torch_tensor_split_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save