Browse Source

flip

pull/6233/head
nihuini 9 months ago
parent
commit
7ec2d03465
No known key found for this signature in database GPG Key ID: 98FD8F4EBC3E5DB8
2 changed files with 4 additions and 7 deletions
  1. +0
    -7
      tools/pnnx/src/pass_level2/torch_flip.cpp
  2. +4
    -0
      tools/pnnx/tests/onnx/test_torch_flip.py

+ 0
- 7
tools/pnnx/src/pass_level2/torch_flip.cpp View File

@@ -55,10 +55,7 @@ pnnx.Output output 1 0 out
int step = captured_params.at("steps").i;

if (axis == 0 && start == -1 && end == INT_MIN + 1 && step == -1)
{
fprintf(stderr, "aaa %d %d %d\n", start, end, step);
return true;
}
}
else // if (captured_params.at("axes").type == 5)
{
@@ -70,14 +67,10 @@ pnnx.Output output 1 0 out
for (size_t i = 0; i < axes.size(); i++)
{
if (starts[i] != -1 || ends[i] != INT_MIN + 1 || steps[i] != -1)
{
fprintf(stderr, "%d %d %d\n", starts[i], ends[i], steps[i]);
return false;
}
}
}

fprintf(stderr, "bbb\n");
return true;
}



+ 4
- 0
tools/pnnx/tests/onnx/test_torch_flip.py View File

@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
@@ -44,6 +45,9 @@ class Model(nn.Module):
return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14

def test():
if version.parse(torch.__version__) < version.parse('1.12'):
return True

net = Model()
net.eval()



Loading…
Cancel
Save