| @@ -232,6 +232,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_stack.cpp | |||
| pass_level2/torch_sum.cpp | |||
| pass_level2/torch_permute.cpp | |||
| pass_level2/torch_tensor_split.cpp | |||
| pass_level2/torch_transpose.cpp | |||
| pass_level2/torch_unbind.cpp | |||
| pass_level2/torch_unsqueeze.cpp | |||
| @@ -265,8 +266,8 @@ set(pnnx_pass_level3_SRCS | |||
| pass_level3/eliminate_noop_math.cpp | |||
| pass_level3/eliminate_tuple_pair.cpp | |||
| pass_level3/expand_quantization_modules.cpp | |||
| pass_level3/fuse_cat_stack_tensors.cpp | |||
| pass_level3/fuse_chunk_split_unbind_unpack.cpp | |||
| pass_level3/fuse_opnto1_tensors.cpp | |||
| pass_level3/fuse_op1ton_unpack.cpp | |||
| pass_level3/fuse_einsum_operands.cpp | |||
| pass_level3/fuse_expression.cpp | |||
| pass_level3/fuse_index_expression.cpp | |||
| @@ -305,6 +306,7 @@ set(pnnx_pass_level5_SRCS | |||
| pass_level5/fuse_linear_batchnorm1d.cpp | |||
| pass_level5/fuse_select_to_unbind.cpp | |||
| pass_level5/fuse_slice_indices.cpp | |||
| pass_level5/fuse_slice_to_tensor_split.cpp | |||
| pass_level5/fuse_static_conv.cpp | |||
| pass_level5/normalize_einsum_equation.cpp | |||
| pass_level5/unroll_rnn_op.cpp | |||
| @@ -319,6 +321,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/convert_torch_chunk.cpp | |||
| pass_ncnn/convert_torch_einsum.cpp | |||
| pass_ncnn/convert_torch_split.cpp | |||
| pass_ncnn/convert_torch_tensor_split.cpp | |||
| pass_ncnn/convert_torch_unbind.cpp | |||
| pass_ncnn/convert_Tensor_select.cpp | |||
| pass_ncnn/eliminate_output.cpp | |||
| @@ -0,0 +1,65 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_tensor_split : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| prim::Constant op_0 0 1 sections value=%sections | |||
| aten::tensor_split op_1 3 1 input sections dim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.tensor_split"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split, 19) | |||
| class torch_tensor_split_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 indices | |||
| pnnx.Input input_2 0 1 dim | |||
| aten::tensor_split op_0 3 1 input indices dim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.tensor_split"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tensor_split_1, 20) | |||
| } // namespace pnnx | |||
| @@ -18,8 +18,8 @@ | |||
| #include "pass_level3/eliminate_noop_math.h" | |||
| #include "pass_level3/eliminate_tuple_pair.h" | |||
| #include "pass_level3/expand_quantization_modules.h" | |||
| #include "pass_level3/fuse_cat_stack_tensors.h" | |||
| #include "pass_level3/fuse_chunk_split_unbind_unpack.h" | |||
| #include "pass_level3/fuse_opnto1_tensors.h" | |||
| #include "pass_level3/fuse_op1ton_unpack.h" | |||
| #include "pass_level3/fuse_einsum_operands.h" | |||
| #include "pass_level3/fuse_expression.h" | |||
| #include "pass_level3/fuse_index_expression.h" | |||
| @@ -39,9 +39,9 @@ void pass_level3(Graph& g, const std::map<std::string, Attribute>& foldable_cons | |||
| { | |||
| assign_unique_name(g); | |||
| fuse_cat_stack_tensors(g); | |||
| fuse_opnto1_tensors(g); | |||
| fuse_chunk_split_unbind_unpack(g); | |||
| fuse_op1ton_unpack(g); | |||
| fuse_einsum_operands(g); | |||
| @@ -12,13 +12,13 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_chunk_split_unbind_unpack.h" | |||
| #include "fuse_op1ton_unpack.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_chunk_split_unbind_unpack(Graph& graph) | |||
| void fuse_op1ton_unpack(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| @@ -28,7 +28,7 @@ void fuse_chunk_split_unbind_unpack(Graph& graph) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind") | |||
| if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind" && op->type != "torch.tensor_split") | |||
| continue; | |||
| if (op->outputs.size() != 1) | |||
| @@ -16,6 +16,6 @@ | |||
| namespace pnnx { | |||
| void fuse_chunk_split_unbind_unpack(Graph& graph); | |||
| void fuse_op1ton_unpack(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -12,13 +12,13 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_cat_stack_tensors.h" | |||
| #include "fuse_opnto1_tensors.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_cat_stack_tensors(Graph& graph) | |||
| void fuse_opnto1_tensors(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| @@ -16,6 +16,6 @@ | |||
| namespace pnnx { | |||
| void fuse_cat_stack_tensors(Graph& graph); | |||
| void fuse_opnto1_tensors(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -34,6 +34,7 @@ | |||
| #include "pass_level5/fuse_linear_batchnorm1d.h" | |||
| #include "pass_level5/fuse_select_to_unbind.h" | |||
| #include "pass_level5/fuse_slice_indices.h" | |||
| #include "pass_level5/fuse_slice_to_tensor_split.h" | |||
| #include "pass_level5/fuse_static_conv.h" | |||
| #include "pass_level5/normalize_einsum_equation.h" | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| @@ -62,6 +63,8 @@ void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_cons | |||
| fuse_select_to_unbind(g); | |||
| fuse_slice_to_tensor_split(g); | |||
| fuse_static_conv(g); | |||
| fuse_conv1d_batchnorm1d(g); | |||
| @@ -0,0 +1,151 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_slice_to_tensor_split.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_slice_to_tensor_split(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "Tensor.slice") | |||
| continue; | |||
| Operand* op_in = op->inputs[0]; | |||
| if (op->params.find("dims") == op->params.end() | |||
| || op->params.find("starts") == op->params.end() | |||
| || op->params.find("ends") == op->params.end() | |||
| || op->params.find("steps") == op->params.end()) | |||
| continue; | |||
| if (op->params.at("dims").ai.size() != 1) | |||
| continue; | |||
| int dim = op->params.at("dims").ai[0]; | |||
| int start = op->params.at("starts").ai[0]; | |||
| int end = op->params.at("ends").ai[0]; | |||
| int step = op->params.at("steps").ai[0]; | |||
| if (start != 0 || step != 1) | |||
| continue; | |||
| // slice 0 i j k ... n | |||
| std::vector<int> tensor_split_indices; | |||
| std::vector<Operator*> slice_n_ops; | |||
| tensor_split_indices.push_back(end); | |||
| slice_n_ops.push_back(op); | |||
| bool full_dimsize_slice = false; | |||
| while (1) | |||
| { | |||
| // find slice with starts == end | |||
| Operator* op2 = 0; | |||
| for (auto x : op_in->consumers) | |||
| { | |||
| if (x->type != "Tensor.slice") | |||
| continue; | |||
| if (x->inputs[0] != op_in) | |||
| continue; | |||
| if (x->params.find("dims") == x->params.end() | |||
| || x->params.find("starts") == x->params.end() | |||
| || x->params.find("ends") == x->params.end() | |||
| || x->params.find("steps") == x->params.end()) | |||
| continue; | |||
| if (x->params.at("dims").ai.size() != 1) | |||
| continue; | |||
| int dim2 = x->params.at("dims").ai[0]; | |||
| int start2 = x->params.at("starts").ai[0]; | |||
| int step2 = x->params.at("steps").ai[0]; | |||
| if (step2 != 1) | |||
| continue; | |||
| if (dim == dim2 && start2 == end) | |||
| { | |||
| op2 = x; | |||
| break; | |||
| } | |||
| } | |||
| if (!op2) | |||
| break; | |||
| int end2 = op2->params.at("ends").ai[0]; | |||
| if (end2 == -1) | |||
| { | |||
| slice_n_ops.push_back(op2); | |||
| full_dimsize_slice = true; | |||
| break; | |||
| } | |||
| tensor_split_indices.push_back(end2); | |||
| slice_n_ops.push_back(op2); | |||
| end = end2; | |||
| } | |||
| if (!full_dimsize_slice) | |||
| continue; | |||
| matched = true; | |||
| // delete all slice ops and replace with tensor_split | |||
| Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, op); | |||
| op_tensor_split->params["dim"] = dim; | |||
| op_tensor_split->params["indices"] = tensor_split_indices; | |||
| op_tensor_split->inputs.push_back(op_in); | |||
| for (size_t j = 0; j < slice_n_ops.size(); j++) | |||
| { | |||
| op_in->consumers.erase(std::find(op_in->consumers.begin(), op_in->consumers.end(), slice_n_ops[j])); | |||
| } | |||
| op_in->consumers.push_back(op_tensor_split); | |||
| op_tensor_split->outputs.resize(slice_n_ops.size()); | |||
| for (size_t j = 0; j < slice_n_ops.size(); j++) | |||
| { | |||
| op_tensor_split->outputs[j] = slice_n_ops[j]->outputs[0]; | |||
| slice_n_ops[j]->outputs[0]->producer = op_tensor_split; | |||
| } | |||
| for (size_t j = 0; j < slice_n_ops.size(); j++) | |||
| { | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), slice_n_ops[j])); | |||
| delete slice_n_ops[j]; | |||
| } | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void fuse_slice_to_tensor_split(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -22,6 +22,7 @@ | |||
| #include "pass_ncnn/convert_torch_chunk.h" | |||
| #include "pass_ncnn/convert_torch_einsum.h" | |||
| #include "pass_ncnn/convert_torch_split.h" | |||
| #include "pass_ncnn/convert_torch_tensor_split.h" | |||
| #include "pass_ncnn/convert_torch_unbind.h" | |||
| #include "pass_ncnn/convert_Tensor_select.h" | |||
| #include "pass_ncnn/eliminate_output.h" | |||
| @@ -90,6 +91,7 @@ void pass_ncnn(Graph& g) | |||
| ncnn::convert_torch_chunk(g); | |||
| ncnn::convert_torch_split(g); | |||
| ncnn::convert_torch_unbind(g); | |||
| ncnn::convert_torch_tensor_split(g); | |||
| ncnn::convert_torch_einsum(g); | |||
| ncnn::convert_Tensor_select(g); | |||
| @@ -0,0 +1,102 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "convert_torch_tensor_split.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void convert_torch_tensor_split(Graph& graph) | |||
| { | |||
| int op_index = 0; | |||
| for (Operator* op : graph.ops) | |||
| { | |||
| if (op->type != "torch.tensor_split") | |||
| continue; | |||
| op->type = "Slice"; | |||
| op->name = std::string("tensor_split_") + std::to_string(op_index++); | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| int axis = op->params.at("dim").i; | |||
| if (axis == batch_index) | |||
| { | |||
| fprintf(stderr, "tensor_split along batch axis %d is not supported\n", batch_index); | |||
| continue; | |||
| } | |||
| if (axis < 0) | |||
| { | |||
| int input_rank = op->inputs[0]->shape.size(); | |||
| axis = input_rank + axis; | |||
| } | |||
| if (op->params.find("sections") != op->params.end()) | |||
| { | |||
| int sections = op->params.at("sections").i; | |||
| if (!op->inputs[0]->shape.empty()) | |||
| { | |||
| int size = op->inputs[0]->shape[axis]; | |||
| if (size % sections != 0) | |||
| { | |||
| fprintf(stderr, "tensor_split with non-perfect divided size %d / %d is not supported\n", size, sections); | |||
| } | |||
| } | |||
| op->params["0"].type = 5; | |||
| op->params["0"].ai.resize(sections, -233); | |||
| op->params.erase("sections"); | |||
| } | |||
| else | |||
| { | |||
| const std::vector<int>& indices = op->params.at("indices").ai; | |||
| op->params["0"].type = 5; | |||
| op->params["0"].ai.resize(indices.size() + 1); | |||
| for (size_t i = 0; i < indices.size() + 1; i++) | |||
| { | |||
| if (i == 0) | |||
| { | |||
| op->params["0"].ai[i] = indices[i]; | |||
| } | |||
| else if (i == indices.size()) | |||
| { | |||
| op->params["0"].ai[i] = -233; | |||
| } | |||
| else | |||
| { | |||
| op->params["0"].ai[i] = indices[i] - indices[i - 1]; | |||
| } | |||
| } | |||
| op->params.erase("indices"); | |||
| } | |||
| if (axis > batch_index) | |||
| axis -= 1; | |||
| op->params["1"] = axis; | |||
| op->params.erase("dim"); | |||
| } | |||
| } | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,25 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void convert_torch_tensor_split(Graph& graph); | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -207,6 +207,7 @@ pnnx_add_test(torch_sum) | |||
| pnnx_add_test(torch_split) | |||
| pnnx_add_test(torch_squeeze) | |||
| pnnx_add_test(torch_stack) | |||
| pnnx_add_test(torch_tensor_split) | |||
| pnnx_add_test(torch_transpose) | |||
| pnnx_add_test(torch_unbind) | |||
| pnnx_add_test(torch_unsqueeze) | |||
| @@ -251,6 +252,7 @@ pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_linear_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_select_to_unbind) | |||
| pnnx_add_test(pnnx_fuse_slice_to_tensor_split) | |||
| if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") | |||
| pnnx_add_test(F_mish) | |||
| @@ -144,6 +144,7 @@ pnnx_ncnn_add_test(torch_permute) | |||
| pnnx_ncnn_add_test(torch_prod) | |||
| pnnx_ncnn_add_test(torch_sum) | |||
| pnnx_ncnn_add_test(torch_squeeze) | |||
| pnnx_ncnn_add_test(torch_tensor_split) | |||
| pnnx_ncnn_add_test(torch_transpose) | |||
| pnnx_ncnn_add_test(torch_unbind) | |||
| pnnx_ncnn_add_test(torch_unsqueeze) | |||
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x0, x1, x2 = torch.tensor_split(x, (12, 13)) | |||
| y0, y1, y2 = torch.tensor_split(y, 3, dim=1) | |||
| z0, z1 = torch.tensor_split(z, (3,), dim=0) | |||
| return x0, x1, x2, y0, y1, y2, z0, z1 | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(100) | |||
| y = torch.rand(3, 15) | |||
| z = torch.rand(5, 9, 3) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_tensor_split.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,15],[5,9,3]") | |||
| # ncnn inference | |||
| import test_torch_tensor_split_ncnn | |||
| b = test_torch_tensor_split_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| print(a0.shape) | |||
| print(b0.shape) | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,70 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x0 = x[:3] | |||
| x1 = x[3:] | |||
| y0 = y[:2,:] | |||
| y1 = y[2:4,:] | |||
| y2 = y[4:,:] | |||
| z0 = z[:,:,:2] | |||
| z1 = z[:,:,2:4] | |||
| z2 = z[:,:,4:7] | |||
| z3 = z[:,:,7:] | |||
| return x0, x1, y0, y1, y2, z0, z1, z2, z3 | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(8) | |||
| y = torch.rand(9, 10) | |||
| z = torch.rand(8, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_pnnx_fuse_slice_to_tensor_split.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_pnnx_fuse_slice_to_tensor_split.pt inputshape=[8],[9,10],[8,9,10]") | |||
| # pnnx inference | |||
| import test_pnnx_fuse_slice_to_tensor_split_pnnx | |||
| b = test_pnnx_fuse_slice_to_tensor_split_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z, w): | |||
| x0, x1, x2 = torch.tensor_split(x, (12, 13)) | |||
| y0, y1, y2 = torch.tensor_split(y, 3, dim=1) | |||
| z0, z1 = torch.tensor_split(z, (3,), dim=0) | |||
| w0, w1, w2, w3, w4 = torch.tensor_split(w, (1, 3, 7, 17), dim=3) | |||
| return x0, x1, x2, y0, y1, y2, z0, z1, w0, w1, w2, w3, w4 | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(100) | |||
| y = torch.rand(3, 16) | |||
| z = torch.rand(5, 9, 3) | |||
| w = torch.rand(6, 13, 6, 22) | |||
| a = net(x, y, z, w) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z, w)) | |||
| mod.save("test_torch_tensor_split.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_tensor_split.pt inputshape=[100],[3,16],[5,9,3],[6,13,6,22]") | |||
| # pnnx inference | |||
| import test_torch_tensor_split_pnnx | |||
| b = test_torch_tensor_split_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||