From 322667a2ab94b36c3fbe82df53d48c9dffbee60d Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 7 Jul 2022 19:50:17 +0800 Subject: [PATCH] pnnx fix fused tensor_split operator insert order (#4006) --- .../src/pass_level5/fuse_slice_to_tensor_split.cpp | 7 ++++++- tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py | 11 +++++++---- .../tests/test_pnnx_fuse_slice_to_tensor_split.py | 7 ++++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp index f7e154f35..2162908b4 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp @@ -57,6 +57,8 @@ void fuse_slice_to_tensor_split(Graph& graph) tensor_split_indices.push_back(end); slice_n_ops.push_back(op); + Operator* cur = op; + bool full_dimsize_slice = false; while (1) { @@ -96,6 +98,9 @@ void fuse_slice_to_tensor_split(Graph& graph) if (!op2) break; + if (std::find(graph.ops.begin(), graph.ops.end(), op2) < std::find(graph.ops.begin(), graph.ops.end(), cur)) + cur = op2; + int end2 = op2->params.at("ends").ai[0]; if (end2 == -1) { @@ -116,7 +121,7 @@ void fuse_slice_to_tensor_split(Graph& graph) matched = true; // delete all slice ops and replace with tensor_split - Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, op); + Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, cur); op_tensor_split->params["dim"] = dim; op_tensor_split->params["indices"] = tensor_split_indices; diff --git a/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py b/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py index 8b0041928..36f65e496 100644 --- a/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py +++ b/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py @@ -24,15 +24,18 @@ class Model(nn.Module): x0 = torch.select(x, 0, 0) x1 = torch.select(x, 0, 1) x2 = torch.select(x, 0, 2) + + z4 = torch.select(x, 2, 4) + z3 = torch.select(x, 2, 3) + y0 = torch.select(x, 1, 0) y1 = torch.select(x, 1, 1) y2 = torch.select(x, 1, 2) y3 = torch.select(x, 1, 3) - z0 = torch.select(x, 2, 0) - z1 = torch.select(x, 2, 1) + z2 = torch.select(x, 2, 2) - z3 = torch.select(x, 2, 3) - z4 = torch.select(x, 2, 4) + z1 = torch.select(x, 2, 1) + z0 = torch.select(x, 2, 0) return x0, x1, x2, y0, y1, y2, y3, z0, z1, z2, z3, z4 diff --git a/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py b/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py index 7bc545999..0db9b7849 100644 --- a/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py +++ b/tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py @@ -24,14 +24,15 @@ class Model(nn.Module): x0 = x[:3] x1 = x[3:] + z3 = z[:,:,7:] + z2 = z[:,:,4:7] + y0 = y[:2,:] y1 = y[2:4,:] y2 = y[4:,:] - z0 = z[:,:,:2] z1 = z[:,:,2:4] - z2 = z[:,:,4:7] - z3 = z[:,:,7:] + z0 = z[:,:,:2] return x0, x1, y0, y1, y2, z0, z1, z2, z3