Browse Source

implement flip layer and pnnx torch.flip conversion

pull/6233/head
nihuini 11 months ago
parent
commit
84faed0b6d
No known key found for this signature in database GPG Key ID: 98FD8F4EBC3E5DB8
16 changed files with 489 additions and 742 deletions
  1. +9
    -0
      docs/developer-guide/operators.md
  2. +64
    -535
      src/layer/flip.cpp
  3. +3
    -14
      src/layer/flip.h
  4. +123
    -73
      tests/test_flip.cpp
  5. +8
    -0
      tools/pnnx/src/ir.h
  6. +6
    -0
      tools/pnnx/src/load_onnx.cpp
  7. +6
    -0
      tools/pnnx/src/load_torchscript.cpp
  8. +70
    -0
      tools/pnnx/src/pass_level2/torch_flip.cpp
  9. +4
    -15
      tools/pnnx/src/pass_ncnn/torch_flip.cpp
  10. +4
    -0
      tools/pnnx/src/pass_onnx.cpp
  11. +4
    -0
      tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
  12. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  13. +29
    -105
      tools/pnnx/tests/ncnn/test_torch_flip.py
  14. +1
    -0
      tools/pnnx/tests/onnx/CMakeLists.txt
  15. +78
    -0
      tools/pnnx/tests/onnx/test_torch_flip.py
  16. +79
    -0
      tools/pnnx/tests/test_torch_flip.py

+ 9
- 0
docs/developer-guide/operators.md View File

@@ -33,6 +33,7 @@
* [Embed](#embed)
* [Exp](#exp)
* [Flatten](#flatten)
* [Flip](#flip)
* [Fold](#fold)
* [GELU](#gelu)
* [GLU](#glu)
@@ -870,6 +871,14 @@ Reshape blob to 1 dimension

* one_blob_only

# Flip

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | axes | array | [ ] | |

# Fold
```
y = fold(x)


+ 64
- 535
src/layer/flip.cpp View File

@@ -1,16 +1,5 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "flip.h"

@@ -23,564 +12,104 @@ Flip::Flip()

int Flip::load_param(const ParamDict& pd)
{
axis = pd.get(0, Mat());
// 调试
// const int *axis_ptr = axis;
// printf("axis_len = %d\n", axis.w);
// printf("axis[0] = %d\n", axis_ptr[0]);
axes = pd.get(0, Mat());

if (axes.w > 4)
{
// only handle up to 4-dim
return -1;
}

return 0;
}

int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
// 已知参数
int dims = bottom_blob.dims;
int w = bottom_blob.w;
int h = bottom_blob.h;
int d = bottom_blob.d;
int channels = bottom_blob.c;
size_t elemsize = bottom_blob.elemsize;

// 校准输入参数
if (axis.w > 4)
if (axes.empty())
{
return -1;
top_blob = bottom_blob;
return 0;
}
const int* axis_ptr = axis;

if (dims == 1)
const int dims = bottom_blob.dims;
const int w = bottom_blob.w;
const int h = bottom_blob.h;
const int d = bottom_blob.d;
const int channels = bottom_blob.c;

int axes_flag[4] = {0};
bool flip_w = false;
bool flip_h = false;
bool flip_d = false;
bool flip_c = false;
{
// 1D 只有一种情况
top_blob.create(w, elemsize, opt.blob_allocator);
const float* ptr = bottom_blob;
float* outptr = top_blob;
for (int i = 0; i < w; i++)
const int* axes_ptr = axes;
for (int i = 0; i < axes.w; i++)
{
outptr[i] = ptr[w - 1 - i];
int axis = axes_ptr[i];
// handle negative axis
if (axis < 0)
axis += dims;
axes_flag[axis] = 1;
}
}
else if (dims == 2)
{
// 2D 有三种,安装上下、左右和上下左右同时翻转;[-2/0上下翻转, -1/1左右翻转,交叉为上下左右翻转]
top_blob.create(w, h, elemsize, opt.blob_allocator);
if (axis.w == 1)
{
if (axis_ptr[0] == -2 || axis_ptr[0] == 0)
{
// 按照行翻转
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.row(h - 1 - i); // 从最后一行开始
float* outptr = top_blob.row(i); // 输出到当前行

// 直接复制整行数据
memcpy(outptr, ptr, w * sizeof(float));
}
}
else
{
// 按照列翻转
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.row(i);
float* outptr = top_blob.row(i);

// 使用临时buffer存储反转的行数据
std::vector<float> line_buffer(w);
for (int j = 0; j < w; j++)
{
line_buffer[j] = ptr[w - 1 - j];
}

// 一次性复制整行
memcpy(outptr, line_buffer.data(), w * sizeof(float));
}
}
}
else
if (dims == 1)
{
// 当axis.w=2时,上下左右都翻转
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.row(h - 1 - i); // 从最后一行开始读取
float* outptr = top_blob.row(i); // 输出到当前行

// 每行内左右翻转
for (int j = 0; j < w; j++)
{
outptr[j] = ptr[w - 1 - j]; // 反向读取每行像素
}
}
flip_w = true;
}
}
else if (dims == 3)
{
top_blob.create(w, h, channels, elemsize, opt.blob_allocator);
if (axis.w == 1)
else if (dims == 2)
{
// w、h、c
// 约定到正数,简化后续判断
int axis0 = axis_ptr[0] < 0 ? 3 + axis_ptr[0] : axis_ptr[0];
if (axis0 == 0)
{
// -3/0 整体上下翻转
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(channels - 1 - i).row(j); // 从最后一个channel开始
float* outptr = top_blob.channel(i).row(j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
else if (axis0 == 1)
{
// -2/1 整体内部上下翻转
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(i).row(h - 1 - j);
float* outptr = top_blob.channel(i).row(j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
else
{
// -1/2 整体左右翻转
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(i).row(j);
float* outptr = top_blob.channel(i).row(j);
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
if (axes_flag[0] == 1) flip_h = true;
if (axes_flag[1] == 1) flip_w = true;
}
else if (axis.w == 2)
else if (dims == 3)
{
// ch、cw、hw
int axis0 = axis_ptr[0] < 0 ? 3 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 3 + axis_ptr[1] : axis_ptr[1];
int axis_sum = axis0 + axis1;
if (axis_sum == 1)
{
// 对应ch
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
// 组合两种翻转:channel维度和行维度同时翻转
const float* ptr = bottom_blob.channel(channels - 1 - i).row(h - 1 - j);
float* outptr = top_blob.channel(i).row(j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
else if (axis_sum == 2)
{
// 对应cw
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(channels - 1 - i).row(j);
float* outptr = top_blob.channel(i).row(j);
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
else if (axis_sum == 3)
{
// 对应hw
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(i).row(h - 1 - j);
float* outptr = top_blob.channel(i).row(j);

// 增加左右翻转
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
if (axes_flag[0] == 1) flip_c = true;
if (axes_flag[1] == 1) flip_h = true;
if (axes_flag[2] == 1) flip_w = true;
}
else
else if (dims == 4)
{
// whc
for (int i = 0; i < channels; i++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(channels - 1 - i).row(h - 1 - j);
float* outptr = top_blob.channel(i).row(j);

// 左右翻转实现完全倒序
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
if (axes_flag[0] == 1) flip_c = true;
if (axes_flag[1] == 1) flip_d = true;
if (axes_flag[2] == 1) flip_h = true;
if (axes_flag[3] == 1) flip_w = true;
}
}
else if (dims == 4)
{
top_blob.create(w, h, d, channels, elemsize, opt.blob_allocator);
if (axis.w == 1)
{
// w、h、d、c
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
if (axis0 == 0)
{
// -4/0 整体上下翻转 torch中按c维度翻转
for (int c = 0; c < channels; c++) // 遍历channels=3
{
int flipped_c = channels - 1 - c; // 计算channels翻转位置
for (int z = 0; z < d; z++) // 遍历d=2维度
{
for (int j = 0; j < h; j++) // 遍历行
{
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(flipped_c).row(z * h + j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else if (axis0 == 1)
{
// -3/1 torh中按d维度内部上下翻转
for (int i = 0; i < channels; i++) // 遍历channels
{
for (int z = 0; z < d; z++) // 遍历d维度
{
for (int j = 0; j < h; j++) // 遍历h维度
{
// 翻转d维度的数据读取位置
const float* ptr = bottom_blob.channel(i).row((d - 1 - z) * h + j);
float* outptr = top_blob.channel(i).row(z * h + j);
// 逐行复制w元素
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else if (axis0 == 2)
{
// -2/2 按torch中H维度翻转 上下
for (int i = 0; i < channels; i++)
{
for (int z = 0; z < d; z++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(i).row(z * h + (h - 1 - j));
float* outptr = top_blob.channel(i).row(z * h + j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else
{
// -1/3 按torch中W维度翻转 左右
for (int i = 0; i < channels; i++)
{
for (int z = 0; z < d; z++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(i).row(z * h + j);
float* outptr = top_blob.channel(i).row(z * h + j);
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
}
else if (axis.w == 2)
{
// dc1、dh2、dw3、ch3、cw4、hw5
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1];
int axis_sum = axis0 + axis1;
if (axis_sum == 1)
{
// 对应dc
for (int c = 0; c < channels; c++) // 遍历channels
{
int flipped_c = channels - 1 - c; // 翻转后的channel位置

for (int z = 0; z < d; z++) // 遍历d维度
{
int flipped_d = d - 1 - z; // 翻转后的d位置

for (int j = 0; j < h; j++) // 遍历行
{
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + j);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else if (axis_sum == 2)
{
// 对应dh
for (int c = 0; c < channels; c++) // 遍历 channels=2 维度
{
int flipped_c = channels - 1 - c; // 计算 c 维度翻转位置 (0→1, 1→0)

for (int z = 0; z < d; z++) // 遍历 d=3 维度
{
// 按翻转顺序逐行复制 h 维度数据
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = top_blob.channel(flipped_c).row(z * h + (h - 1 - i)); // 保持z维度顺序,翻转h维度
memcpy(outptr, ptr, w * sizeof(float)); // 按行复制,保持 w 维度顺序
}
}
}
}
else if (axis_sum == 3)
{
// 对应dw;有一个为0或3
if (axis0 == 0 || axis0 == 3)
{
// 对应dw
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度

for (int z = 0; z < d; z++) // d维度保持不变
{
for (int j = 0; j < h; j++) // h维度保持不变
{
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(flipped_c).row(z * h + j);

// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
else
{
// 对应ch
for (int c = 0; c < channels; c++)
{
for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z;

for (int j = 0; j < h; j++)
{
int flipped_h = h - 1 - j;
// 读取源数据
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(c).row(flipped_d * h + flipped_h);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
}
else if (axis_sum == 4)
{
// 对应cw
for (int c = 0; c < channels; c++)
{
for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z; // 翻转 d 维度

for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(c).row(flipped_d * h + j); // c维度保持不变

// 翻转 w 维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
else
{
// 对应hw
for (int c = 0; c < channels; c++)
{
for (int z = 0; z < d; z++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + j);
float* outptr = top_blob.channel(c).row(z * h + (h - 1 - j)); // 翻转 h 维度
top_blob.create_like(bottom_blob, opt.blob_allocator);
if (top_blob.empty())
return -100;

// 翻转 w 维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
}
else if (axis.w == 3)
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
for (int z = 0; z < d; z++)
{
// dch3、dcw4、dhw5,chw6
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1];
int axis2 = axis_ptr[2] < 0 ? 4 + axis_ptr[2] : axis_ptr[2];
int axis_sum = axis0 + axis1 + axis2;
if (axis_sum == 3)
{
// 对应dch,除w外,其余全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c;

for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z;

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = top_blob.channel(flipped_c).depth(flipped_d).row(h - 1 - i);
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else if (axis_sum == 4)
{
// 对应dcw,除h外,其余全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度

for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z; // 翻转d维度

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + i); // h维度保持不变

// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
else if (axis_sum == 5)
for (int i = 0; i < h; i++)
{
// 对应dhw,除了d外全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度
int q2 = flip_c ? channels - 1 - q : q;
int z2 = flip_d ? d - 1 - z : z;
int i2 = flip_h ? h - 1 - i : i;

for (int z = 0; z < d; z++) // d维度保持不变
{
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = top_blob.channel(flipped_c).depth(z).row(h - 1 - i); // 翻转h维度
const float* ptr = bottom_blob.channel(q2).depth(z2).row(i2);
float* outptr = top_blob.channel(q).depth(z).row(i);

// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
else if (axis_sum == 6)
{
// 对应chw,除了c外全翻转
for (int c = 0; c < channels; c++) // c维度保持不变
if (flip_w)
{
for (int z = 0; z < d; z++)
ptr += w - 1;
for (int j = 0; j < w; j++)
{
int flipped_d = d - 1 - z; // 翻转d维度

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = top_blob.channel(c).depth(flipped_d).row(h - 1 - i); // 翻转h维度
// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
*outptr++ = *ptr--;
}
}
}
}
else
{
// dchw全部翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度

for (int z = 0; z < d; z++)
else
{
int flipped_d = d - 1 - z; // 翻转d维度

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + (h - 1 - i)); // 翻转h维度

// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
}
else
{
return -1;
}

return 0;
}


+ 3
- 14
src/layer/flip.h View File

@@ -1,16 +1,5 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_FLIP_H
#define LAYER_FLIP_H
@@ -29,7 +18,7 @@ public:
virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
Mat axis; // 翻转维度
Mat axes;
};

} // namespace ncnn


+ 123
- 73
tests/test_flip.cpp View File

@@ -1,132 +1,182 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "layer.h"
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "testutil.h"

// 为兼容低于c++11
// ncnn::Mat axis_mat(axis.size());
// for (size_t i = 0; i < axis.size(); i++)
// {
// axis_mat[i] = axis[i];
// }
static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static int test_flip(const ncnn::Mat& a, const ncnn::Mat& axis)
// Write the elements of `a` to stderr as "[ e0 e1 ... ]".
// Every element is preceded by one space; no trailing newline is emitted,
// so the caller decides how the line ends.
static void print_int_array(const std::vector<int>& a)
{
    fprintf(stderr, "[");

    std::vector<int>::const_iterator it = a.begin();
    while (it != a.end())
    {
        fprintf(stderr, " %d", *it);
        ++it;
    }

    fprintf(stderr, " ]");
}

static int test_flip(const ncnn::Mat& a, const std::vector<int>& axes_array)
{
ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(0, axis);
pd.set(0, axes);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("Flip", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d) axis_w=%d\n", a.dims, a.w, a.h, a.d, a.c, axis.w);
fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " axes=");
print_int_array(axes_array);
fprintf(stderr, "\n");
}

return ret;
}

// Run the Flip layer on `a` for every non-empty combination of axes that is
// valid for a.dims, stage by stage (1-D axes first, then the 2-D, 3-D and
// 4-D combinations), stopping at the first failure.
//
// Returns the raw test_flip() code when the very first (axis 0) case decides
// the result, otherwise 0 when every case passed and 1 when any case failed.
static int test_flip_nd(const ncnn::Mat& a)
{
    // stage 1: axis 0 exists for every rank
    const int status_1d = test_flip(a, IntArray(0));
    if (status_1d != 0 || a.dims == 1)
        return status_1d;

    // stage 2: all axis subsets of a 2-D blob
    const bool failed_2d = test_flip(a, IntArray(0))
                           || test_flip(a, IntArray(1))
                           || test_flip(a, IntArray(0, 1));
    if (failed_2d || a.dims == 2)
        return failed_2d ? 1 : 0;

    // stage 3: all axis subsets of a 3-D blob
    const bool failed_3d = test_flip(a, IntArray(0))
                           || test_flip(a, IntArray(1))
                           || test_flip(a, IntArray(2))
                           || test_flip(a, IntArray(0, 1))
                           || test_flip(a, IntArray(0, 2))
                           || test_flip(a, IntArray(1, 2))
                           || test_flip(a, IntArray(0, 1, 2));
    if (failed_3d || a.dims == 3)
        return failed_3d ? 1 : 0;

    // stage 4: all axis subsets of a 4-D blob
    const bool failed_4d = test_flip(a, IntArray(0))
                           || test_flip(a, IntArray(1))
                           || test_flip(a, IntArray(2))
                           || test_flip(a, IntArray(3))
                           || test_flip(a, IntArray(0, 1))
                           || test_flip(a, IntArray(0, 2))
                           || test_flip(a, IntArray(0, 3))
                           || test_flip(a, IntArray(1, 2))
                           || test_flip(a, IntArray(1, 3))
                           || test_flip(a, IntArray(2, 3))
                           || test_flip(a, IntArray(0, 1, 2))
                           || test_flip(a, IntArray(0, 1, 3))
                           || test_flip(a, IntArray(0, 2, 3))
                           || test_flip(a, IntArray(1, 2, 3))
                           || test_flip(a, IntArray(0, 1, 2, 3));

    return failed_4d ? 1 : 0;
}

static int test_flip_0()
{
ncnn::Mat a = RandomMat(5, 6, 7, 24);
ncnn::Mat b = RandomMat(7, 8, 9, 12);
ncnn::Mat c = RandomMat(3, 4, 5, 13);

return 0
|| test_flip(RandomMat(2, 3, 4, 5), IntArrayMat(0))
|| test_flip(RandomMat(3, 2, 4, 5), IntArrayMat(1))
|| test_flip(RandomMat(4, 3, 2, 5), IntArrayMat(2))
|| test_flip(RandomMat(2, 3, 1, 5), IntArrayMat(3))
|| test_flip(RandomMat(6, 3, 4, 5), IntArrayMat(0, 1))
|| test_flip(RandomMat(2, 3, 1, 6), IntArrayMat(0, 2))
|| test_flip(RandomMat(5, 1, 2, 5), IntArrayMat(0, 3))
|| test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(1, 2))
|| test_flip(RandomMat(4, 5, 2, 3), IntArrayMat(1, 3))
|| test_flip(RandomMat(2, 6, 4, 5), IntArrayMat(2, 3))
|| test_flip(RandomMat(6, 1, 4, 5), IntArrayMat(0, 1, 2))
|| test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(0, 1, 3))
|| test_flip(RandomMat(4, 3, 3, 5), IntArrayMat(0, 2, 3))
|| test_flip(RandomMat(4, 3, 4, 5), IntArrayMat(1, 2, 3))
|| test_flip(RandomMat(6, 3, 3, 2), IntArrayMat(0, 1, 2, 3));
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_1()
{
ncnn::Mat a = RandomMat(5, 7, 24);
ncnn::Mat b = RandomMat(7, 9, 12);
ncnn::Mat c = RandomMat(3, 5, 13);

return 0
|| test_flip(RandomMat(2, 3, 5), IntArrayMat(0))
|| test_flip(RandomMat(3, 3, 5), IntArrayMat(1))
|| test_flip(RandomMat(4, 3, 5), IntArrayMat(2))
|| test_flip(RandomMat(3, 1, 5), IntArrayMat(0, 1))
|| test_flip(RandomMat(3, 2, 5), IntArrayMat(0, 2))
|| test_flip(RandomMat(3, 3, 4), IntArrayMat(1, 2))
|| test_flip(RandomMat(4, 3, 2), IntArrayMat(0, 1, 2));
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_2()
{
ncnn::Mat a = RandomMat(15, 24);
ncnn::Mat b = RandomMat(17, 12);
ncnn::Mat c = RandomMat(19, 15);

return 0
|| test_flip(RandomMat(8, 2), IntArrayMat(-2))
|| test_flip(RandomMat(16, 3), IntArrayMat(-1))
|| test_flip(RandomMat(7, 2), IntArrayMat(-2, -1));
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

static int test_flip_3()
{
ncnn::Mat a = RandomMat(128);
ncnn::Mat b = RandomMat(124);
ncnn::Mat c = RandomMat(127);

return 0
|| test_flip(RandomMat(18), IntArrayMat(-1));
|| test_flip_nd(a)
|| test_flip_nd(b)
|| test_flip_nd(c);
}

int main()
{
SRAND(7767517);

return 0
|| test_flip_0()
|| test_flip_1()
|| test_flip_2()
|| test_flip_3();
}
}

+ 8
- 0
tools/pnnx/src/ir.h View File

@@ -62,14 +62,18 @@ public:
: type(2)
{
if (_l == std::numeric_limits<long>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<long>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<long>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<long>::min() + 1) _l = INT_MIN + 1;
i = (int)_l;
}
Parameter(long long _l)
: type(2)
{
if (_l == std::numeric_limits<long long>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<long long>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<long long>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<long long>::min() + 1) _l = INT_MIN + 1;
i = (int)_l;
}
Parameter(float _f)
@@ -99,7 +103,9 @@ public:
{
int64_t _l = x;
if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<int64_t>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<int64_t>::min() + 1) _l = INT_MIN + 1;
ai.push_back((int)_l);
}
}
@@ -114,7 +120,9 @@ public:
{
int64_t _l = x;
if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
if (_l == std::numeric_limits<int64_t>::max() - 1) _l = INT_MAX - 1;
if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
if (_l == std::numeric_limits<int64_t>::min() + 1) _l = INT_MIN + 1;
ai.push_back((int)_l);
}
}


+ 6
- 0
tools/pnnx/src/load_onnx.cpp View File

@@ -76,7 +76,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
type = 2;
int64_t i64 = attr.i();
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
break;
}
@@ -99,7 +101,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
{
int64_t i64 = attr.ints().at(i);
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
ai.push_back(i64);
}
break;
@@ -165,7 +169,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr)
i64 = tensor.int64_data().at(0);
}
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
}
else if (tensor.data_type() == onnx::TensorProto::FLOAT)


+ 6
- 0
tools/pnnx/src/load_torchscript.cpp View File

@@ -100,7 +100,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
type = 2;
int64_t i64 = value_node->i(torch::jit::attr::value);
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
break;
}
@@ -141,7 +143,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
type = 2;
int64_t i64 = t.item<int64_t>();
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
i = (int)i64;
}
else if (t.scalar_type() == c10::ScalarType::Int)
@@ -193,7 +197,9 @@ Parameter::Parameter(const torch::jit::Node* value_node)
for (auto i64 : i64s)
{
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
ai.push_back(i64);
}
break;


+ 70
- 0
tools/pnnx/src/pass_level2/torch_flip.cpp View File

@@ -27,4 +27,74 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 60)

class torch_flip_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Slice op_0 1 1 input out axes=%axes starts=%starts ends=%ends steps=%steps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.flip";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("axes").type == 2)
{
int axis = captured_params.at("axes").i;
int start = captured_params.at("starts").i;
int end = captured_params.at("ends").i;
int step = captured_params.at("steps").i;

if (axis == 0 && start == -1 && end == INT_MIN + 1 && step == -1)
{
fprintf(stderr, "aaa %d %d %d\n", start, end, step);
return true;
}
}
else // if (captured_params.at("axes").type == 5)
{
const std::vector<int>& axes = captured_params.at("axes").ai;
const std::vector<int>& starts = captured_params.at("starts").ai;
const std::vector<int>& ends = captured_params.at("ends").ai;
const std::vector<int>& steps = captured_params.at("steps").ai;

for (size_t i = 0; i < axes.size(); i++)
{
if (starts[i] != -1 || ends[i] != INT_MIN + 1 || steps[i] != -1)
{
fprintf(stderr, "%d %d %d\n", starts[i], ends[i], steps[i]);
return false;
}
}
}

fprintf(stderr, "bbb\n");
return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("axes").type == 2)
{
int dim = captured_params.at("axes").i;
op->params["dims"] = std::vector<int>{dim};
}
else // if (captured_params.at("axes").type == 5)
{
op->params["dims"] = captured_params.at("axes");
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip_onnx, 60)

} // namespace pnnx

+ 4
- 15
tools/pnnx/src/pass_ncnn/torch_flip.cpp View File

@@ -1,16 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "pass_ncnn.h"

namespace pnnx {
@@ -44,7 +34,6 @@ pnnx.Output output 1 0 out
{
const std::vector<int>& dims = captured_params.at("dims").ai;

// 设置参数
op->params["0"] = dims;
}
};
@@ -53,4 +42,4 @@ REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flip, 20)

} // namespace ncnn

} // namespace pnnx
} // namespace pnnx

+ 4
- 0
tools/pnnx/src/pass_onnx.cpp View File

@@ -875,7 +875,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
i64 = tensor.int64_data().at(0);
}
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
op_const->params["value"] = (int)i64;
}
else if (tensor.data_type() == onnx::TensorProto::FLOAT)
@@ -961,7 +963,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
{
int64_t i64 = ai[k];
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
expr += std::to_string(i64);
if (k != (int)ai.size() - 1)
expr += ",";


+ 4
- 0
tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp View File

@@ -146,7 +146,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model)
}

if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;

onnx::AttributeProto* attr = node->add_attribute();
attr->set_name(std::string(attr_name));
@@ -242,7 +244,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model)
for (auto i64 : ai)
{
if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;

attr->add_ints((int)i64);
}


+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -212,6 +212,7 @@ pnnx_add_test(torch_einsum)
pnnx_add_test(torch_eq)
pnnx_add_test(torch_diag)
pnnx_add_test(torch_flatten)
pnnx_add_test(torch_flip)
pnnx_add_test(torch_full)
pnnx_add_test(torch_full_like)
pnnx_add_test(torch_gather)


+ 29
- 105
tools/pnnx/tests/ncnn/test_torch_flip.py View File

@@ -1,58 +1,15 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, d):
def forward(self, x, y, z, w):
# 1D
x0 = torch.flip(x, [0])
# 2D
@@ -68,86 +25,53 @@ class Model(nn.Module):
z5 = torch.flip(z, [1, 2])
z6 = torch.flip(z, [0, 1, 2])
# 4D
d0 = torch.flip(d, [-1])
d1 = torch.flip(d, [-2])
d2 = torch.flip(d, [-3])
d3 = torch.flip(d, [-4])
d4 = torch.flip(d, [0, 1])
d5 = torch.flip(d, [0, 2])
d6 = torch.flip(d, [0, 3])
d7 = torch.flip(d, [1, 2])
d8 = torch.flip(d, [1, 3])
d9 = torch.flip(d, [2, 3])
d10 = torch.flip(d, [0, 1, 2])
d11 = torch.flip(d, [0, 1, 3])
d12 = torch.flip(d, [0, 2, 3])
d13 = torch.flip(d, [1, 2, 3])
d14 = torch.flip(d, [0, 1, 2, 3])

return (
x0,
y0,
y1,
y2,
z0,
z1,
z2,
z3,
z4,
z5,
z6,
d0,
d1,
d2,
d3,
d4,
d5,
d6,
d7,
d8,
d9,
d10,
d11,
d12,
d13,
d14,
)

w0 = torch.flip(w, [-1])
w1 = torch.flip(w, [-2])
w2 = torch.flip(w, [-3])
w3 = torch.flip(w, [-4])
w4 = torch.flip(w, [0, 1])
w5 = torch.flip(w, [0, 2])
w6 = torch.flip(w, [0, 3])
w7 = torch.flip(w, [1, 2])
w8 = torch.flip(w, [1, 3])
w9 = torch.flip(w, [2, 3])
w10 = torch.flip(w, [0, 1, 2])
w11 = torch.flip(w, [0, 1, 3])
w12 = torch.flip(w, [0, 2, 3])
w13 = torch.flip(w, [1, 2, 3])
w14 = torch.flip(w, [0, 1, 2, 3])

return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14

# NOTE(review): this span is a rendered diff hunk whose +/- markers were lost,
# so two versions of several statements sit next to each other. Judging by the
# forward() signature shown earlier in this hunk (d -> w) and by the sibling
# test files, the statements using `d` / the small shapes / torch.allclose
# appear to be the removed side and the ones using `w` / the larger shapes /
# torch.equal the added side -- confirm against the repository.
def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(36) # 1D
y = torch.rand(4, 7) # 2D
z = torch.rand(3, 4, 5) # 3D
d = torch.rand(4, 2, 6, 7) # 4D
x = torch.rand(36)
y = torch.rand(14, 17)
z = torch.rand(13, 14, 15)
w = torch.rand(48, 12, 16, 17)

# eager PyTorch result used as the reference
a = net(x, y, z, d)
a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, d))
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_flip.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

os.system(
"../../src/pnnx test_torch_flip.pt inputshape=[36],[4,7],[3,4,5],[4,2,6,7]"
)

# pnnx inference
# ncnn inference
# test_torch_flip_ncnn is imported only after the pnnx command has run;
# presumably that command writes the module -- TODO confirm
import test_torch_flip_ncnn

b = test_torch_flip_ncnn.test_inference()

# compare reference and converted outputs pairwise; the first mismatch fails the test
for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
if not torch.equal(a0, b0):
return False
return True


if __name__ == "__main__":
if test():
exit(0)


+ 1
- 0
tools/pnnx/tests/onnx/CMakeLists.txt View File

@@ -157,6 +157,7 @@ pnnx_onnx_add_test(torch_ceil)
pnnx_onnx_add_test(torch_chunk)
pnnx_onnx_add_test(torch_clamp)
pnnx_onnx_add_test(torch_flatten)
pnnx_onnx_add_test(torch_flip)
pnnx_onnx_add_test(torch_floor)
pnnx_onnx_add_test(torch_logical_not)
pnnx_onnx_add_test(torch_logical_and)


+ 78
- 0
tools/pnnx/tests/onnx/test_torch_flip.py View File

@@ -0,0 +1,78 @@
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Parameter-free module that applies torch.flip to four inputs over many axis lists."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        """Return a flat tuple of 26 flipped tensors.

        Results are ordered as 1 flip of x, 3 of y, 7 of z and 15 of w,
        following the axis lists in the table below.
        """
        cases = (
            # 1D
            (x, ([0],)),
            # 2D
            (y, ([0], [1], [-2, -1])),
            # 3D
            (z, ([0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2])),
            # 4D
            (w, ([-1], [-2], [-3], [-4],
                 [0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3],
                 [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3],
                 [0, 1, 2, 3])),
        )

        flipped = []
        for tensor, axis_lists in cases:
            for dims in axis_lists:
                flipped.append(torch.flip(tensor, dims))
        return tuple(flipped)

def test():
    """Export Model to ONNX, run the pnnx converter on it, and compare outputs.

    Returns True when every output of test_torch_flip_pnnx.test_inference()
    is exactly equal (torch.equal) to the eager PyTorch output, else False.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    inputs = (
        torch.rand(36),
        torch.rand(14, 17),
        torch.rand(13, 14, 15),
        torch.rand(48, 12, 16, 17),
    )

    # eager PyTorch result used as the reference
    a = net(*inputs)

    # export onnx
    torch.onnx.export(net, inputs, "test_torch_flip.onnx")

    # onnx to pnnx
    import os
    os.system("../../src/pnnx test_torch_flip.onnx inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

    # pnnx inference
    import test_torch_flip_pnnx
    b = test_torch_flip_pnnx.test_inference()

    # exact comparison, no tolerance
    return all(torch.equal(a0, b0) for a0, b0 in zip(a, b))

if __name__ == "__main__":
    # process exit status: 0 when test() passes, 1 otherwise
    exit(0 if test() else 1)

+ 79
- 0
tools/pnnx/tests/test_torch_flip.py View File

@@ -0,0 +1,79 @@
# Copyright 2025 Tencent
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Parameter-free module that applies torch.flip to four inputs over many axis lists."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        """Return a flat tuple of 26 flipped tensors.

        Results are ordered as 1 flip of x, 3 of y, 7 of z and 15 of w,
        following the axis lists in the table below.
        """
        cases = (
            # 1D
            (x, ([0],)),
            # 2D
            (y, ([0], [1], [-2, -1])),
            # 3D
            (z, ([0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2])),
            # 4D
            (w, ([-1], [-2], [-3], [-4],
                 [0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3],
                 [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3],
                 [0, 1, 2, 3])),
        )

        flipped = []
        for tensor, axis_lists in cases:
            for dims in axis_lists:
                flipped.append(torch.flip(tensor, dims))
        return tuple(flipped)

def test():
    """Trace Model to TorchScript, run the pnnx converter on it, and compare outputs.

    Returns True when every output of test_torch_flip_pnnx.test_inference()
    is exactly equal (torch.equal) to the eager PyTorch output, else False.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    inputs = (
        torch.rand(36),
        torch.rand(14, 17),
        torch.rand(13, 14, 15),
        torch.rand(48, 12, 16, 17),
    )

    # eager PyTorch result used as the reference
    a = net(*inputs)

    # export torchscript
    traced = torch.jit.trace(net, inputs)
    traced.save("test_torch_flip.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]")

    # pnnx inference
    import test_torch_flip_pnnx
    b = test_torch_flip_pnnx.test_inference()

    # exact comparison, no tolerance
    return all(torch.equal(a0, b0) for a0, b0 in zip(a, b))

if __name__ == "__main__":
    # process exit status: 0 when test() passes, 1 otherwise
    exit(0 if test() else 1)

Loading…
Cancel
Save