Browse Source

implement flip layer and pnnx torch.flip conversion (#6233)

Co-authored-by: 佰阅 <43716063+Baiyuetribe@users.noreply.github.com>
pull/6236/head
nihui GitHub 10 months ago
parent
commit
9b91fe5153
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
20 changed files with 730 additions and 0 deletions
  1. +9
    -0
      docs/developer-guide/operators.md
  2. +1
    -0
      src/CMakeLists.txt
  3. +117
    -0
      src/layer/flip.cpp
  4. +26
    -0
      src/layer/flip.h
  5. +1
    -0
      tests/CMakeLists.txt
  6. +182
    -0
      tests/test_flip.cpp
  7. +1
    -0
      tools/pnnx/src/CMakeLists.txt
  8. +8
    -0
      tools/pnnx/src/ir.h
  9. +6
    -0
      tools/pnnx/src/load_onnx.cpp
  10. +6
    -0
      tools/pnnx/src/load_torchscript.cpp
  11. +64
    -0
      tools/pnnx/src/pass_level2/torch_flip.cpp
  12. +58
    -0
      tools/pnnx/src/pass_ncnn/torch_flip.cpp
  13. +4
    -0
      tools/pnnx/src/pass_onnx.cpp
  14. +4
    -0
      tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
  15. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  16. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  17. +79
    -0
      tools/pnnx/tests/ncnn/test_torch_flip.py
  18. +1
    -0
      tools/pnnx/tests/onnx/CMakeLists.txt
  19. +82
    -0
      tools/pnnx/tests/onnx/test_torch_flip.py
  20. +79
    -0
      tools/pnnx/tests/test_torch_flip.py

+ 9
- 0
docs/developer-guide/operators.md View File

@@ -33,6 +33,7 @@
* [Embed](#embed)
* [Exp](#exp)
* [Flatten](#flatten)
* [Flip](#flip)
* [Fold](#fold)
* [GELU](#gelu)
* [GLU](#glu)
@@ -870,6 +871,14 @@ Reshape blob to 1 dimension

* one_blob_only

# Flip

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | axes | array | [ ] | |

# Fold
```
y = fold(x)


+ 1
- 0
src/CMakeLists.txt View File

@@ -170,6 +170,7 @@ ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)
ncnn_add_layer(Spectrogram)
ncnn_add_layer(InverseSpectrogram)
ncnn_add_layer(Flip)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)


+ 117
- 0
src/layer/flip.cpp View File

@@ -0,0 +1,117 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "flip.h"

namespace ncnn {

Flip::Flip()
{
one_blob_only = true;
}

int Flip::load_param(const ParamDict& pd)
{
axes = pd.get(0, Mat());

if (axes.w > 4)
{
// only handle up to 4-dim
return -1;
}

return 0;
}

int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
if (axes.empty())
{
top_blob = bottom_blob;
return 0;
}

const int dims = bottom_blob.dims;
const int w = bottom_blob.w;
const int h = bottom_blob.h;
const int d = bottom_blob.d;
const int channels = bottom_blob.c;

int axes_flag[4] = {0};
bool flip_w = false;
bool flip_h = false;
bool flip_d = false;
bool flip_c = false;
{
const int* axes_ptr = axes;
for (int i = 0; i < axes.w; i++)
{
int axis = axes_ptr[i];
// handle negative axis
if (axis < 0)
axis += dims;
axes_flag[axis] = 1;
}

if (dims == 1)
{
flip_w = true;
}
else if (dims == 2)
{
if (axes_flag[0] == 1) flip_h = true;
if (axes_flag[1] == 1) flip_w = true;
}
else if (dims == 3)
{
if (axes_flag[0] == 1) flip_c = true;
if (axes_flag[1] == 1) flip_h = true;
if (axes_flag[2] == 1) flip_w = true;
}
else if (dims == 4)
{
if (axes_flag[0] == 1) flip_c = true;
if (axes_flag[1] == 1) flip_d = true;
if (axes_flag[2] == 1) flip_h = true;
if (axes_flag[3] == 1) flip_w = true;
}
}

top_blob.create_like(bottom_blob, opt.blob_allocator);
if (top_blob.empty())
return -100;

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
for (int z = 0; z < d; z++)
{
for (int i = 0; i < h; i++)
{
int q2 = flip_c ? channels - 1 - q : q;
int z2 = flip_d ? d - 1 - z : z;
int i2 = flip_h ? h - 1 - i : i;

const float* ptr = bottom_blob.channel(q2).depth(z2).row(i2);
float* outptr = top_blob.channel(q).depth(z).row(i);

if (flip_w)
{
ptr += w - 1;
for (int j = 0; j < w; j++)
{
*outptr++ = *ptr--;
}
}
else
{
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}

return 0;
}

} // namespace ncnn

+ 26
- 0
src/layer/flip.h View File

@@ -0,0 +1,26 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_FLIP_H
#define LAYER_FLIP_H

#include "layer.h"

namespace ncnn {

class Flip : public Layer
{
public:
Flip();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
Mat axes;
};

} // namespace ncnn

#endif // LAYER_FLIP_H

+ 1
- 0
tests/CMakeLists.txt View File

@@ -107,6 +107,7 @@ ncnn_add_layer_test(Embed)
ncnn_add_layer_test(Erf)
ncnn_add_layer_test(ExpandDims)
ncnn_add_layer_test(Flatten)
ncnn_add_layer_test(Flip)
ncnn_add_layer_test(Fold)
ncnn_add_layer_test(GELU)
ncnn_add_layer_test(GLU)


+ 182
- 0
tests/test_flip.cpp View File

@@ -0,0 +1,182 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "testutil.h"

static std::vector<int> IntArray(int a0)
{
std::vector<int> m(1);
m[0] = a0;
return m;
}

static std::vector<int> IntArray(int a0, int a1)
{
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static std::vector<int> IntArray(int a0, int a1, int a2)
{
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const std::vector<int>& a)
{
fprintf(stderr, "[");
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_flip(const ncnn::Mat& a, const std::vector<int>& axes_array)
{
ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(0, axes);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("Flip", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " axes=");
print_int_array(axes_array);
fprintf(stderr, "\n");
}

return ret;
}

static int test_flip_nd(const ncnn::Mat& a)
{
int ret1 = test_flip(a, IntArray(0));

if (a.dims == 1 || ret1 != 0)
return ret1;

int ret2 = 0
|| test_flip(a, IntArray(0))
|| test_flip(a, IntArray(1))
|| test_flip(a, IntArray(0, 1));

if (a.dims == 2 || ret2 != 0)
return ret2;

int ret3 = 0
|| test_flip(a, IntArray(0))
|| test_flip(a, IntArray(1))
|| test_flip(a, IntArray(2))
|| test_flip(a, IntArray(0, 1))
|| test_flip(a, IntArray(0, 2))
|| test_flip(a, IntArray(1, 2))
|| test_flip(a, IntArray(0, 1, 2));

if (a.dims == 3 || ret3 != 0)
return ret3;

int ret4 = 0
|| test_flip(a, IntArray(0))
|| test_flip(a, IntArray(1))
|| test_flip(a, IntArray(2))
|| test_flip(a, IntArray(3))
|| test_flip(a, IntArray(0, 1))
|| test_flip(a, IntArray(0, 2))
|| test_flip(a, IntArray(0, 3))
|| test_flip(a, IntArray(1, 2))
|| test_flip(a, IntArray(1, 3))
|| test_flip(a, IntArray(2, 3))
|| test_flip(a, IntArray(0, 1, 2))
|| test_flip(a, IntArray(0, 1, 3))
|| test_flip(a, IntArray(0, 2, 3))
|| test_flip(a, IntArray(1, 2, 3))
|| test_flip(a, IntArray(0, 1, 2, 3));

return ret4;
}

static int test_flip_0()
{
ncnn::Mat a = RandomMat(5, 6, 7, 24);
ncnn::Mat b = RandomMat(7, 8, 9, 12);
ncnn::Mat c = RandomMat(3, 4, 5, 13);

return 0
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_1()
{
ncnn::Mat a = RandomMat(5, 7, 24);
ncnn::Mat b = RandomMat(7, 9, 12);
ncnn::Mat c = RandomMat(3, 5, 13);

return 0
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_2()
{
ncnn::Mat a = RandomMat(15, 24);
ncnn::Mat b = RandomMat(17, 12);
ncnn::Mat c = RandomMat(19, 15);

return 0
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_3()
{
ncnn::Mat a = RandomMat(128);
ncnn::Mat b = RandomMat(124);
ncnn::Mat c = RandomMat(127);

return 0
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

int main()
{
SRAND(7767517);

return 0
|| test_flip_0()
|| test_flip_1()
|| test_flip_2()
|| test_flip_3();
}

+ 1
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -592,6 +592,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_cumsum.cpp
pass_ncnn/torch_diag.cpp
pass_ncnn/torch_flatten.cpp
pass_ncnn/torch_flip.cpp
pass_ncnn/torch_istft.cpp
pass_ncnn/torch_logsumexp.cpp
pass_ncnn/torch_matmul.cpp


+ 8
- 0
tools/pnnx/src/ir.h View File

@@ -62,14 +62,18 @@ public:
: type(2)
{
if (_l == std::numeric_limits<long>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<long>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<long>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<long>::min() + 1) _l = INT_MIN + 1;
i = (int)_l;
}
Parameter(long long _l)
: type(2)
{
if (_l == std::numeric_limits<long long>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<long long>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<long long>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<long long>::min() + 1) _l = INT_MIN + 1;
i = (int)_l;
}
Parameter(float _f)
@@ -99,7 +103,9 @@ public:
{
int64_t _l = x;
if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<int64_t>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<int64_t>::min() + 1) _l = INT_MIN + 1;
ai.push_back((int)_l);
}
}
@@ -114,7 +120,9 @@ public:
{
int64_t _l = x;
if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<int64_t>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<int64_t>::min() + 1) _l = INT_MIN + 1;
ai.push_back((int)_l);
}
}


+ 6
- 0
tools/pnnx/src/load_onnx.cpp View File

@@ -76,7 +76,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
type = 2;
int64_t i64 = attr.i();
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
break;
}
@@ -99,7 +101,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
{
int64_t i64 = attr.ints().at(i);
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
ai.push_back(i64);
}
break;
@@ -165,7 +169,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
i64 = tensor.int64_data().at(0);
}
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
}
else if (tensor.data_type() == onnx::TensorProto::FLOAT)


+ 6
- 0
tools/pnnx/src/load_torchscript.cpp View File

@@ -100,7 +100,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
type = 2;
int64_t i64 = value_node->i(torch::jit::attr::value);
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
break;
}
@@ -141,7 +143,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
type = 2;
int64_t i64 = t.item<int64_t>();
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
}
else if (t.scalar_type() == c10::ScalarType::Int)
@@ -193,7 +197,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
for (auto i64 : i64s)
{
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
ai.push_back(i64);
}
break;


+ 64
- 0
tools/pnnx/src/pass_level2/torch_flip.cpp View File

@@ -27,4 +27,68 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 60)

class torch_flip_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Slice op_0 1 1 input out axes=%axes starts=%starts ends=%ends steps=%steps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.flip";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("axes").type == 2)
{
int start = captured_params.at("starts").i;
int end = captured_params.at("ends").i;
int step = captured_params.at("steps").i;

if (start == -1 && end == INT_MIN + 1 && step == -1)
return true;
}
else // if (captured_params.at("axes").type == 5)
{
const std::vector<int>& axes = captured_params.at("axes").ai;
const std::vector<int>& starts = captured_params.at("starts").ai;
const std::vector<int>& ends = captured_params.at("ends").ai;
const std::vector<int>& steps = captured_params.at("steps").ai;

for (size_t i = 0; i < axes.size(); i++)
{
if (starts[i] != -1 || ends[i] != INT_MIN + 1 || steps[i] != -1)
return false;
}

return true;
}

return false;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("axes").type == 2)
{
int dim = captured_params.at("axes").i;
op->params["dims"] = std::vector<int>{dim};
}
else // if (captured_params.at("axes").type == 5)
{
op->params["dims"] = captured_params.at("axes");
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip_onnx, 60)

} // namespace pnnx

+ 58
- 0
tools/pnnx/src/pass_ncnn/torch_flip.cpp View File

@@ -0,0 +1,58 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_flip : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.flip op_0 1 1 input out dims=%dims
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Flip";
}

const char* name_str() const
{
return "flip";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::vector<int>& dims = captured_params.at("dims").ai;

const int batch_index = op->inputs[0]->params["__batch_index"].i;

// drop batch index
std::vector<int> new_dims;
for (int i = 0; i < (int)dims.size(); i++)
{
if (dims[i] == batch_index)
continue;

int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i];
new_dims.push_back(new_dim);
}

op->params["0"] = new_dims;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flip, 20)

} // namespace ncnn

} // namespace pnnx

+ 4
- 0
tools/pnnx/src/pass_onnx.cpp View File

@@ -875,7 +875,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
i64 = tensor.int64_data().at(0);
}
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
op_const->params["value"] = (int)i64;
}
else if (tensor.data_type() == onnx::TensorProto::FLOAT)
@@ -961,7 +963,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
{
int64_t i64 = ai[k];
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
expr += std::to_string(i64);
if (k != (int)ai.size() - 1)
expr += ",";


+ 4
- 0
tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp View File

@@ -146,7 +146,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model)
}

if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;

onnx::AttributeProto* attr = node->add_attribute();
attr->set_name(std::string(attr_name));
@@ -242,7 +244,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model)
for (auto i64 : ai)
{
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;

attr->add_ints((int)i64);
}


+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -212,6 +212,7 @@ pnnx_add_test(torch_einsum)
pnnx_add_test(torch_eq)
pnnx_add_test(torch_diag)
pnnx_add_test(torch_flatten)
pnnx_add_test(torch_flip)
pnnx_add_test(torch_full)
pnnx_add_test(torch_full_like)
pnnx_add_test(torch_gather)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -189,6 +189,7 @@ pnnx_ncnn_add_test(torch_clamp)
pnnx_ncnn_add_test(torch_cos)
pnnx_ncnn_add_test(torch_exp)
pnnx_ncnn_add_test(torch_floor)
pnnx_ncnn_add_test(torch_flip)
pnnx_ncnn_add_test(torch_log)
pnnx_ncnn_add_test(torch_log10)
pnnx_ncnn_add_test(torch_maximum)


+ 79
- 0
tools/pnnx/tests/ncnn/test_torch_flip.py View File

@@ -0,0 +1,79 @@
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
# 1D
x0 = torch.flip(x, [0])
# 2D
y0 = torch.flip(y, [0])
y1 = torch.flip(y, [1])
y2 = torch.flip(y, [-2, -1])
# 3D
z0 = torch.flip(z, [0])
z1 = torch.flip(z, [1])
z2 = torch.flip(z, [2])
z3 = torch.flip(z, [0, 1])
z4 = torch.flip(z, [0, 2])
z5 = torch.flip(z, [1, 2])
z6 = torch.flip(z, [0, 1, 2])
# 4D
w0 = torch.flip(w, [-1])
w1 = torch.flip(w, [-2])
w2 = torch.flip(w, [-3])
w3 = torch.flip(w, [-4])
w4 = torch.flip(w, [0, 1])
w5 = torch.flip(w, [0, 2])
w6 = torch.flip(w, [0, 3])
w7 = torch.flip(w, [1, 2])
w8 = torch.flip(w, [1, 3])
w9 = torch.flip(w, [2, 3])
w10 = torch.flip(w, [0, 1, 2])
w11 = torch.flip(w, [0, 1, 3])
w12 = torch.flip(w, [0, 2, 3])
w13 = torch.flip(w, [1, 2, 3])
w14 = torch.flip(w, [0, 1, 2, 3])

return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(36)
y = torch.rand(14, 17)
z = torch.rand(13, 14, 15)
w = torch.rand(48, 12, 16, 17)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_flip.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

# ncnn inference
import test_torch_flip_ncnn
b = test_torch_flip_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 1
- 0
tools/pnnx/tests/onnx/CMakeLists.txt View File

@@ -157,6 +157,7 @@ pnnx_onnx_add_test(torch_ceil)
pnnx_onnx_add_test(torch_chunk)
pnnx_onnx_add_test(torch_clamp)
pnnx_onnx_add_test(torch_flatten)
pnnx_onnx_add_test(torch_flip)
pnnx_onnx_add_test(torch_floor)
pnnx_onnx_add_test(torch_logical_not)
pnnx_onnx_add_test(torch_logical_and)


+ 82
- 0
tools/pnnx/tests/onnx/test_torch_flip.py View File

@@ -0,0 +1,82 @@
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
# 1D
x0 = torch.flip(x, [0])
# 2D
y0 = torch.flip(y, [0])
y1 = torch.flip(y, [1])
y2 = torch.flip(y, [-2, -1])
# 3D
z0 = torch.flip(z, [0])
z1 = torch.flip(z, [1])
z2 = torch.flip(z, [2])
z3 = torch.flip(z, [0, 1])
z4 = torch.flip(z, [0, 2])
z5 = torch.flip(z, [1, 2])
z6 = torch.flip(z, [0, 1, 2])
# 4D
w0 = torch.flip(w, [-1])
w1 = torch.flip(w, [-2])
w2 = torch.flip(w, [-3])
w3 = torch.flip(w, [-4])
w4 = torch.flip(w, [0, 1])
w5 = torch.flip(w, [0, 2])
w6 = torch.flip(w, [0, 3])
w7 = torch.flip(w, [1, 2])
w8 = torch.flip(w, [1, 3])
w9 = torch.flip(w, [2, 3])
w10 = torch.flip(w, [0, 1, 2])
w11 = torch.flip(w, [0, 1, 3])
w12 = torch.flip(w, [0, 2, 3])
w13 = torch.flip(w, [1, 2, 3])
w14 = torch.flip(w, [0, 1, 2, 3])

return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14

def test():
if version.parse(torch.__version__) < version.parse('1.12'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(36)
y = torch.rand(14, 17)
z = torch.rand(13, 14, 15)
w = torch.rand(48, 12, 16, 17)

a = net(x, y, z, w)

# export onnx
torch.onnx.export(net, (x, y, z, w), "test_torch_flip.onnx")

# onnx to pnnx
import os
os.system("../../src/pnnx test_torch_flip.onnx inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

# pnnx inference
import test_torch_flip_pnnx
b = test_torch_flip_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 79
- 0
tools/pnnx/tests/test_torch_flip.py View File

@@ -0,0 +1,79 @@
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
# 1D
x0 = torch.flip(x, [0])
# 2D
y0 = torch.flip(y, [0])
y1 = torch.flip(y, [1])
y2 = torch.flip(y, [-2, -1])
# 3D
z0 = torch.flip(z, [0])
z1 = torch.flip(z, [1])
z2 = torch.flip(z, [2])
z3 = torch.flip(z, [0, 1])
z4 = torch.flip(z, [0, 2])
z5 = torch.flip(z, [1, 2])
z6 = torch.flip(z, [0, 1, 2])
# 4D
w0 = torch.flip(w, [-1])
w1 = torch.flip(w, [-2])
w2 = torch.flip(w, [-3])
w3 = torch.flip(w, [-4])
w4 = torch.flip(w, [0, 1])
w5 = torch.flip(w, [0, 2])
w6 = torch.flip(w, [0, 3])
w7 = torch.flip(w, [1, 2])
w8 = torch.flip(w, [1, 3])
w9 = torch.flip(w, [2, 3])
w10 = torch.flip(w, [0, 1, 2])
w11 = torch.flip(w, [0, 1, 3])
w12 = torch.flip(w, [0, 2, 3])
w13 = torch.flip(w, [1, 2, 3])
w14 = torch.flip(w, [0, 1, 2, 3])

return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(36)
y = torch.rand(14, 17)
z = torch.rand(13, 14, 15)
w = torch.rand(48, 12, 16, 17)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_flip.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

# pnnx inference
import test_torch_flip_pnnx
b = test_torch_flip_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save