diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index a55e53192..4e9bc6cab 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -232,6 +232,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_stack.cpp pass_level2/torch_sum.cpp pass_level2/torch_permute.cpp + pass_level2/torch_tensor_split.cpp pass_level2/torch_transpose.cpp pass_level2/torch_unbind.cpp pass_level2/torch_unsqueeze.cpp @@ -265,8 +266,8 @@ set(pnnx_pass_level3_SRCS pass_level3/eliminate_noop_math.cpp pass_level3/eliminate_tuple_pair.cpp pass_level3/expand_quantization_modules.cpp - pass_level3/fuse_cat_stack_tensors.cpp - pass_level3/fuse_chunk_split_unbind_unpack.cpp + pass_level3/fuse_opnto1_tensors.cpp + pass_level3/fuse_op1ton_unpack.cpp pass_level3/fuse_einsum_operands.cpp pass_level3/fuse_expression.cpp pass_level3/fuse_index_expression.cpp @@ -305,6 +306,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_linear_batchnorm1d.cpp pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_slice_indices.cpp + pass_level5/fuse_slice_to_tensor_split.cpp pass_level5/fuse_static_conv.cpp pass_level5/normalize_einsum_equation.cpp pass_level5/unroll_rnn_op.cpp @@ -319,6 +321,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/convert_torch_chunk.cpp pass_ncnn/convert_torch_einsum.cpp pass_ncnn/convert_torch_split.cpp + pass_ncnn/convert_torch_tensor_split.cpp pass_ncnn/convert_torch_unbind.cpp pass_ncnn/convert_Tensor_select.cpp pass_ncnn/eliminate_output.cpp diff --git a/tools/pnnx/src/pass_level2/torch_tensor_split.cpp b/tools/pnnx/src/pass_level2/torch_tensor_split.cpp new file mode 100644 index 000000000..319b70ce5 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_tensor_split.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_tensor_split : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 sections value=%sections +aten::tensor_split op_1 3 1 input sections dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.tensor_split"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split, 19) + +class torch_tensor_split_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 indices +pnnx.Input input_2 0 1 dim +aten::tensor_split op_0 3 1 input indices dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.tensor_split"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split_1, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3.cpp b/tools/pnnx/src/pass_level3.cpp index eeda02219..1d7759cc8 100644 --- a/tools/pnnx/src/pass_level3.cpp +++ b/tools/pnnx/src/pass_level3.cpp @@ -18,8 +18,8 @@ #include "pass_level3/eliminate_noop_math.h" #include "pass_level3/eliminate_tuple_pair.h" #include "pass_level3/expand_quantization_modules.h" -#include "pass_level3/fuse_cat_stack_tensors.h" -#include "pass_level3/fuse_chunk_split_unbind_unpack.h" +#include "pass_level3/fuse_opnto1_tensors.h" +#include "pass_level3/fuse_op1ton_unpack.h" #include "pass_level3/fuse_einsum_operands.h" #include "pass_level3/fuse_expression.h" #include "pass_level3/fuse_index_expression.h" @@ -39,9 +39,9 @@ void pass_level3(Graph& g, const std::map& foldable_cons { assign_unique_name(g); - fuse_cat_stack_tensors(g); + fuse_opnto1_tensors(g); - fuse_chunk_split_unbind_unpack(g); + fuse_op1ton_unpack(g); fuse_einsum_operands(g); diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp similarity index 92% rename from tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp rename to tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp index ee3ebec14..882b4dd11 100644 --- a/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp @@ -12,13 +12,13 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "fuse_chunk_split_unbind_unpack.h" +#include "fuse_op1ton_unpack.h" #include #include "pass_level2.h" namespace pnnx { -void fuse_chunk_split_unbind_unpack(Graph& graph) +void fuse_op1ton_unpack(Graph& graph) { while (1) { @@ -28,7 +28,7 @@ void fuse_chunk_split_unbind_unpack(Graph& graph) { Operator* op = graph.ops[i]; - if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind") + if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind" && op->type != "torch.tensor_split") continue; if (op->outputs.size() != 1) diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h similarity index 93% rename from tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h rename to tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h index da6fa0dec..a584c00cb 100644 --- a/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h +++ b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_chunk_split_unbind_unpack(Graph& graph); +void fuse_op1ton_unpack(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.cpp b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp similarity index 96% rename from tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.cpp rename to tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp index 9a2b279b1..f3dcecbc7 100644 --- a/tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.cpp +++ b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp @@ -12,13 +12,13 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "fuse_cat_stack_tensors.h" +#include "fuse_opnto1_tensors.h" #include #include "pass_level2.h" namespace pnnx { -void fuse_cat_stack_tensors(Graph& graph) +void fuse_opnto1_tensors(Graph& graph) { while (1) { diff --git a/tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.h b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h similarity index 94% rename from tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.h rename to tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h index 81d84bc91..4fb990a48 100644 --- a/tools/pnnx/src/pass_level3/fuse_cat_stack_tensors.h +++ b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_cat_stack_tensors(Graph& graph); +void fuse_opnto1_tensors(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index cdd3573e3..24eb9a971 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -34,6 +34,7 @@ #include "pass_level5/fuse_linear_batchnorm1d.h" #include "pass_level5/fuse_select_to_unbind.h" #include "pass_level5/fuse_slice_indices.h" +#include "pass_level5/fuse_slice_to_tensor_split.h" #include "pass_level5/fuse_static_conv.h" #include "pass_level5/normalize_einsum_equation.h" #include "pass_level4/dead_code_elimination.h" @@ -62,6 +63,8 @@ void pass_level5(Graph& g, const std::map& foldable_cons fuse_select_to_unbind(g); + fuse_slice_to_tensor_split(g); + fuse_static_conv(g); fuse_conv1d_batchnorm1d(g); diff --git a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp new file mode 100644 index 000000000..f7e154f35 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp @@ -0,0 +1,151 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_slice_to_tensor_split.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_slice_to_tensor_split(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Tensor.slice") + continue; + + Operand* op_in = op->inputs[0]; + + if (op->params.find("dims") == op->params.end() + || op->params.find("starts") == op->params.end() + || op->params.find("ends") == op->params.end() + || op->params.find("steps") == op->params.end()) + continue; + + if (op->params.at("dims").ai.size() != 1) + continue; + + int dim = op->params.at("dims").ai[0]; + int start = op->params.at("starts").ai[0]; + int end = op->params.at("ends").ai[0]; + int step = op->params.at("steps").ai[0]; + if (start != 0 || step != 1) + continue; + + // slice 0 i j k ... n + std::vector tensor_split_indices; + std::vector slice_n_ops; + + tensor_split_indices.push_back(end); + slice_n_ops.push_back(op); + + bool full_dimsize_slice = false; + while (1) + { + // find slice with starts == end + Operator* op2 = 0; + + for (auto x : op_in->consumers) + { + if (x->type != "Tensor.slice") + continue; + + if (x->inputs[0] != op_in) + continue; + + if (x->params.find("dims") == x->params.end() + || x->params.find("starts") == x->params.end() + || x->params.find("ends") == x->params.end() + || x->params.find("steps") == x->params.end()) + continue; + + if (x->params.at("dims").ai.size() != 1) + continue; + + int dim2 = x->params.at("dims").ai[0]; + int start2 = x->params.at("starts").ai[0]; + int step2 = x->params.at("steps").ai[0]; + if (step2 != 1) + continue; + + if (dim == dim2 && start2 == end) + { + op2 = x; + break; + } + } + + if (!op2) + break; + + int end2 = op2->params.at("ends").ai[0]; + if (end2 == -1) + { + slice_n_ops.push_back(op2); + full_dimsize_slice = true; + break; + } + + tensor_split_indices.push_back(end2); + slice_n_ops.push_back(op2); + + end = end2; + } + + if (!full_dimsize_slice) + continue; + + matched = true; + + // delete all slice ops and replace with tensor_split + Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, op); + op_tensor_split->params["dim"] = dim; + op_tensor_split->params["indices"] = tensor_split_indices; + + op_tensor_split->inputs.push_back(op_in); + for (size_t j = 0; j < slice_n_ops.size(); j++) + { + op_in->consumers.erase(std::find(op_in->consumers.begin(), op_in->consumers.end(), slice_n_ops[j])); + } + op_in->consumers.push_back(op_tensor_split); + + op_tensor_split->outputs.resize(slice_n_ops.size()); + for (size_t j = 0; j < slice_n_ops.size(); j++) + { + op_tensor_split->outputs[j] = slice_n_ops[j]->outputs[0]; + slice_n_ops[j]->outputs[0]->producer = op_tensor_split; + } + + for (size_t j = 0; j < slice_n_ops.size(); j++) + { + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), slice_n_ops[j])); + delete slice_n_ops[j]; + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h new file mode 100644 index 000000000..1c172838b --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_slice_to_tensor_split(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp index c2452fb13..595a2ed67 100644 --- a/tools/pnnx/src/pass_ncnn.cpp +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -22,6 +22,7 @@ #include "pass_ncnn/convert_torch_chunk.h" #include "pass_ncnn/convert_torch_einsum.h" #include "pass_ncnn/convert_torch_split.h" +#include "pass_ncnn/convert_torch_tensor_split.h" #include "pass_ncnn/convert_torch_unbind.h" #include "pass_ncnn/convert_Tensor_select.h" #include "pass_ncnn/eliminate_output.h" @@ -90,6 +91,7 @@ void pass_ncnn(Graph& g) ncnn::convert_torch_chunk(g); ncnn::convert_torch_split(g); ncnn::convert_torch_unbind(g); + ncnn::convert_torch_tensor_split(g); ncnn::convert_torch_einsum(g); ncnn::convert_Tensor_select(g); diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp new file mode 100644 index 000000000..989104caa --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp @@ -0,0 +1,102 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_torch_tensor_split.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_tensor_split(Graph& graph) +{ + int op_index = 0; + + for (Operator* op : graph.ops) + { + if (op->type != "torch.tensor_split") + continue; + + op->type = "Slice"; + op->name = std::string("tensor_split_") + std::to_string(op_index++); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int axis = op->params.at("dim").i; + if (axis == batch_index) + { + fprintf(stderr, "tensor_split along batch axis %d is not supported\n", batch_index); + continue; + } + + if (axis < 0) + { + int input_rank = op->inputs[0]->shape.size(); + axis = input_rank + axis; + } + + if (op->params.find("sections") != op->params.end()) + { + int sections = op->params.at("sections").i; + + if (!op->inputs[0]->shape.empty()) + { + int size = op->inputs[0]->shape[axis]; + if (size % sections != 0) + { + fprintf(stderr, "tensor_split with non-perfect divided size %d / %d is not supported\n", size, sections); + } + } + + op->params["0"].type = 5; + op->params["0"].ai.resize(sections, -233); + + op->params.erase("sections"); + } + else + { + const std::vector& indices = op->params.at("indices").ai; + + op->params["0"].type = 5; + op->params["0"].ai.resize(indices.size() + 1); + + for (size_t i = 0; i < indices.size() + 1; i++) + { + if (i == 0) + { + op->params["0"].ai[i] = indices[i]; + } + else if (i == indices.size()) + { + op->params["0"].ai[i] = -233; + } + else + { + op->params["0"].ai[i] = indices[i] - indices[i - 1]; + } + } + + op->params.erase("indices"); + } + + if (axis > batch_index) + axis -= 1; + + op->params["1"] = axis; + op->params.erase("dim"); + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.h b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.h new file mode 100644 index 000000000..7793029db --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_tensor_split(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index c368e53e8..08ef539b5 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -207,6 +207,7 @@ pnnx_add_test(torch_sum) pnnx_add_test(torch_split) pnnx_add_test(torch_squeeze) pnnx_add_test(torch_stack) +pnnx_add_test(torch_tensor_split) pnnx_add_test(torch_transpose) pnnx_add_test(torch_unbind) pnnx_add_test(torch_unsqueeze) @@ -251,6 +252,7 @@ pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) pnnx_add_test(pnnx_fuse_select_to_unbind) +pnnx_add_test(pnnx_fuse_slice_to_tensor_split) if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_add_test(F_mish) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 2418c49ce..affefd771 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -144,6 +144,7 @@ pnnx_ncnn_add_test(torch_permute) pnnx_ncnn_add_test(torch_prod) pnnx_ncnn_add_test(torch_sum) pnnx_ncnn_add_test(torch_squeeze) +pnnx_ncnn_add_test(torch_tensor_split) pnnx_ncnn_add_test(torch_transpose) pnnx_ncnn_add_test(torch_unbind) pnnx_ncnn_add_test(torch_unsqueeze) diff --git a/tools/pnnx/tests/ncnn/test_torch_tensor_split.py b/tools/pnnx/tests/ncnn/test_torch_tensor_split.py new file mode 100644 index 000000000..aa3eb7c59 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_tensor_split.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x0, x1, x2 = torch.tensor_split(x, (12, 13)) + y0, y1, y2 = torch.tensor_split(y, 3, dim=1) + z0, z1 = torch.tensor_split(z, (3,), dim=0) + return x0, x1, x2, y0, y1, y2, z0, z1 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(100) + y = torch.rand(3, 15) + z = torch.rand(5, 9, 3) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_tensor_split.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,15],[5,9,3]") + + # ncnn inference + import test_torch_tensor_split_ncnn + b = test_torch_tensor_split_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + print(a0.shape) + print(b0.shape) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py b/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py new file mode 100644 index 000000000..7bc545999 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x0 = x[:3] + x1 = x[3:] + + y0 = y[:2,:] + y1 = y[2:4,:] + y2 = y[4:,:] + + z0 = z[:,:,:2] + z1 = z[:,:,2:4] + z2 = z[:,:,4:7] + z3 = z[:,:,7:] + + return x0, x1, y0, y1, y2, z0, z1, z2, z3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(8) + y = torch.rand(9, 10) + z = torch.rand(8, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_pnnx_fuse_slice_to_tensor_split.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_slice_to_tensor_split.pt inputshape=[8],[9,10],[8,9,10]") + + # pnnx inference + import test_pnnx_fuse_slice_to_tensor_split_pnnx + b = test_pnnx_fuse_slice_to_tensor_split_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_tensor_split.py b/tools/pnnx/tests/test_torch_tensor_split.py new file mode 100644 index 000000000..2221d5939 --- /dev/null +++ b/tools/pnnx/tests/test_torch_tensor_split.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x0, x1, x2 = torch.tensor_split(x, (12, 13)) + y0, y1, y2 = torch.tensor_split(y, 3, dim=1) + z0, z1 = torch.tensor_split(z, (3,), dim=0) + w0, w1, w2, w3, w4 = torch.tensor_split(w, (1, 3, 7, 17), dim=3) + return x0, x1, x2, y0, y1, y2, z0, z1, w0, w1, w2, w3, w4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(100) + y = torch.rand(3, 16) + z = torch.rand(5, 9, 3) + w = torch.rand(6, 13, 6, 22) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_tensor_split.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,16],[5,9,3],[6,13,6,22]") + + # pnnx inference + import test_torch_tensor_split_pnnx + b = test_torch_tensor_split_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)