Browse Source

pnnx fix fused tensor_split operator insert order (#4006)

tags/20220721
nihui GitHub 3 years ago
parent
commit
322667a2ab
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 8 deletions
  1. +6
    -1
      tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp
  2. +7
    -4
      tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py
  3. +4
    -3
      tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py

+ 6
- 1
tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp View File

@@ -57,6 +57,8 @@ void fuse_slice_to_tensor_split(Graph& graph)
tensor_split_indices.push_back(end);
slice_n_ops.push_back(op);

Operator* cur = op;

bool full_dimsize_slice = false;
while (1)
{
@@ -96,6 +98,9 @@ void fuse_slice_to_tensor_split(Graph& graph)
if (!op2)
break;

if (std::find(graph.ops.begin(), graph.ops.end(), op2) < std::find(graph.ops.begin(), graph.ops.end(), cur))
cur = op2;

int end2 = op2->params.at("ends").ai[0];
if (end2 == -1)
{
@@ -116,7 +121,7 @@ void fuse_slice_to_tensor_split(Graph& graph)
matched = true;

// delete all slice ops and replace with tensor_split
Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, op);
Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, cur);
op_tensor_split->params["dim"] = dim;
op_tensor_split->params["indices"] = tensor_split_indices;



+ 7
- 4
tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py View File

@@ -24,15 +24,18 @@ class Model(nn.Module):
x0 = torch.select(x, 0, 0)
x1 = torch.select(x, 0, 1)
x2 = torch.select(x, 0, 2)

z4 = torch.select(x, 2, 4)
z3 = torch.select(x, 2, 3)

y0 = torch.select(x, 1, 0)
y1 = torch.select(x, 1, 1)
y2 = torch.select(x, 1, 2)
y3 = torch.select(x, 1, 3)
z0 = torch.select(x, 2, 0)
z1 = torch.select(x, 2, 1)

z2 = torch.select(x, 2, 2)
z3 = torch.select(x, 2, 3)
z4 = torch.select(x, 2, 4)
z1 = torch.select(x, 2, 1)
z0 = torch.select(x, 2, 0)

return x0, x1, x2, y0, y1, y2, y3, z0, z1, z2, z3, z4



+ 4
- 3
tools/pnnx/tests/test_pnnx_fuse_slice_to_tensor_split.py View File

@@ -24,14 +24,15 @@ class Model(nn.Module):
x0 = x[:3]
x1 = x[3:]

z3 = z[:,:,7:]
z2 = z[:,:,4:7]

y0 = y[:2,:]
y1 = y[2:4,:]
y2 = y[4:,:]

z0 = z[:,:,:2]
z1 = z[:,:,2:4]
z2 = z[:,:,4:7]
z3 = z[:,:,7:]
z0 = z[:,:,:2]

return x0, x1, y0, y1, y2, z0, z1, z2, z3



Loading…
Cancel
Save