diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index cab7bdca1..a1c1c9cf3 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -33,6 +33,7 @@ * [Embed](#embed) * [Exp](#exp) * [Flatten](#flatten) +* [Flip](#flip) * [Fold](#fold) * [GELU](#gelu) * [GLU](#glu) @@ -870,6 +871,14 @@ Reshape blob to 1 dimension * one_blob_only +# Flip + +* one_blob_only + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | axes | array | [ ] | | + # Fold ``` y = fold(x) diff --git a/src/layer/flip.cpp b/src/layer/flip.cpp index ae191c4ed..01201feda 100644 --- a/src/layer/flip.cpp +++ b/src/layer/flip.cpp @@ -1,16 +1,5 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause #include "flip.h" @@ -23,564 +12,104 @@ Flip::Flip() int Flip::load_param(const ParamDict& pd) { - axis = pd.get(0, Mat()); - // 调试 - // const int *axis_ptr = axis; - // printf("axis_len = %d\n", axis.w); - // printf("axis[0] = %d\n", axis_ptr[0]); + axes = pd.get(0, Mat()); + + if (axes.w > 4) + { + // only handle up to 4-dim + return -1; + } + return 0; } int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - // 已知参数 - int dims = bottom_blob.dims; - int w = bottom_blob.w; - int h = bottom_blob.h; - int d = bottom_blob.d; - int channels = bottom_blob.c; - size_t elemsize = bottom_blob.elemsize; - - // 校准输入参数 - if (axis.w > 4) + if (axes.empty()) { - return -1; + top_blob = bottom_blob; + return 0; } - const int* axis_ptr = axis; - if (dims == 1) + const int dims = bottom_blob.dims; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int d = bottom_blob.d; + const int channels = bottom_blob.c; + + int axes_flag[4] = {0}; + bool flip_w = false; + bool flip_h = false; + bool flip_d = false; + bool flip_c = false; { - // 1D 只有一种情况 - top_blob.create(w, elemsize, opt.blob_allocator); - const float* ptr = bottom_blob; - float* outptr = top_blob; - for (int i = 0; i < w; i++) + const int* axes_ptr = axes; + for (int i = 0; i < axes.w; i++) { - outptr[i] = ptr[w - 1 - i]; + int axis = axes_ptr[i]; + // handle negative axis + if (axis < 0) + axis += dims; + axes_flag[axis] = 1; } - } - else if (dims == 2) - { - // 2D 有三种,安装上下、左右和上下左右同时翻转;[-2/0上下翻转, -1/1左右翻转,交叉为上下左右翻转] - top_blob.create(w, h, elemsize, opt.blob_allocator); - if (axis.w == 1) - { - if (axis_ptr[0] == -2 || axis_ptr[0] == 0) - { - // 按照行翻转 - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.row(h - 1 - i); // 从最后一行开始 - float* outptr = top_blob.row(i); // 输出到当前行 - // 直接复制整行数据 - memcpy(outptr, ptr, w * sizeof(float)); - } - } - else - { - // 按照列翻转 - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.row(i); - float* outptr = top_blob.row(i); - - // 使用临时buffer存储反转的行数据 - std::vector line_buffer(w); - for (int j = 0; j < w; j++) - { - line_buffer[j] = ptr[w - 1 - j]; - } - - // 一次性复制整行 - memcpy(outptr, line_buffer.data(), w * sizeof(float)); - } - } - } - else + if (dims == 1) { - // 当axis.w=2时,上下左右都翻转 - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.row(h - 1 - i); // 从最后一行开始读取 - float* outptr = top_blob.row(i); // 输出到当前行 - - // 每行内左右翻转 - for (int j = 0; j < w; j++) - { - outptr[j] = ptr[w - 1 - j]; // 反向读取每行像素 - } - } + flip_w = true; } - } - else if (dims == 3) - { - top_blob.create(w, h, channels, elemsize, opt.blob_allocator); - if (axis.w == 1) + else if (dims == 2) { - // w、h、c - // 约定到正数,简化后续判断 - int axis0 = axis_ptr[0] < 0 ? 3 + axis_ptr[0] : axis_ptr[0]; - if (axis0 == 0) - { - // -3/0 整体上下翻转 - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(channels - 1 - i).row(j); // 从最后一个channel开始 - float* outptr = top_blob.channel(i).row(j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - else if (axis0 == 1) - { - // -2/1 整体内部上下翻转 - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(i).row(h - 1 - j); - float* outptr = top_blob.channel(i).row(j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - else - { - // -1/2 整体左右翻转 - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(i).row(j); - float* outptr = top_blob.channel(i).row(j); - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } + if (axes_flag[0] == 1) flip_h = true; + if (axes_flag[1] == 1) flip_w = true; } - else if (axis.w == 2) + else if (dims == 3) { - // ch、cw、hw - int axis0 = axis_ptr[0] < 0 ? 3 + axis_ptr[0] : axis_ptr[0]; - int axis1 = axis_ptr[1] < 0 ? 3 + axis_ptr[1] : axis_ptr[1]; - int axis_sum = axis0 + axis1; - if (axis_sum == 1) - { - // 对应ch - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - // 组合两种翻转:channel维度和行维度同时翻转 - const float* ptr = bottom_blob.channel(channels - 1 - i).row(h - 1 - j); - float* outptr = top_blob.channel(i).row(j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - else if (axis_sum == 2) - { - // 对应cw - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(channels - 1 - i).row(j); - float* outptr = top_blob.channel(i).row(j); - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - else if (axis_sum == 3) - { - // 对应hw - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(i).row(h - 1 - j); - float* outptr = top_blob.channel(i).row(j); - - // 增加左右翻转 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } + if (axes_flag[0] == 1) flip_c = true; + if (axes_flag[1] == 1) flip_h = true; + if (axes_flag[2] == 1) flip_w = true; } - else + else if (dims == 4) { - // whc - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(channels - 1 - i).row(h - 1 - j); - float* outptr = top_blob.channel(i).row(j); - - // 左右翻转实现完全倒序 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } + if (axes_flag[0] == 1) flip_c = true; + if (axes_flag[1] == 1) flip_d = true; + if (axes_flag[2] == 1) flip_h = true; + if (axes_flag[3] == 1) flip_w = true; } } - else if (dims == 4) - { - top_blob.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (axis.w == 1) - { - // w、h、d、c - int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0]; - if (axis0 == 0) - { - // -4/0 整体上下翻转 torch中按c维度翻转 - for (int c = 0; c < channels; c++) // 遍历channels=3 - { - int flipped_c = channels - 1 - c; // 计算channels翻转位置 - for (int z = 0; z < d; z++) // 遍历d=2维度 - { - for (int j = 0; j < h; j++) // 遍历行 - { - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(flipped_c).row(z * h + j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - else if (axis0 == 1) - { - // -3/1 torh中按d维度内部上下翻转 - for (int i = 0; i < channels; i++) // 遍历channels - { - for (int z = 0; z < d; z++) // 遍历d维度 - { - for (int j = 0; j < h; j++) // 遍历h维度 - { - // 翻转d维度的数据读取位置 - const float* ptr = bottom_blob.channel(i).row((d - 1 - z) * h + j); - float* outptr = top_blob.channel(i).row(z * h + j); - // 逐行复制w元素 - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - else if (axis0 == 2) - { - // -2/2 按torch中H维度翻转 上下 - for (int i = 0; i < channels; i++) - { - for (int z = 0; z < d; z++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(i).row(z * h + (h - 1 - j)); - float* outptr = top_blob.channel(i).row(z * h + j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - else - { - // -1/3 按torch中W维度翻转 左右 - for (int i = 0; i < channels; i++) - { - for (int z = 0; z < d; z++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(i).row(z * h + j); - float* outptr = top_blob.channel(i).row(z * h + j); - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - } - else if (axis.w == 2) - { - // dc1、dh2、dw3、ch3、cw4、hw5 - int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0]; - int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1]; - int axis_sum = axis0 + axis1; - if (axis_sum == 1) - { - // 对应dc - for (int c = 0; c < channels; c++) // 遍历channels - { - int flipped_c = channels - 1 - c; // 翻转后的channel位置 - for (int z = 0; z < d; z++) // 遍历d维度 - { - int flipped_d = d - 1 - z; // 翻转后的d位置 - - for (int j = 0; j < h; j++) // 遍历行 - { - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + j); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - else if (axis_sum == 2) - { - // 对应dh - for (int c = 0; c < channels; c++) // 遍历 channels=2 维度 - { - int flipped_c = channels - 1 - c; // 计算 c 维度翻转位置 (0→1, 1→0) - - for (int z = 0; z < d; z++) // 遍历 d=3 维度 - { - // 按翻转顺序逐行复制 h 维度数据 - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).row(z * h + i); - float* outptr = top_blob.channel(flipped_c).row(z * h + (h - 1 - i)); // 保持z维度顺序,翻转h维度 - memcpy(outptr, ptr, w * sizeof(float)); // 按行复制,保持 w 维度顺序 - } - } - } - } - else if (axis_sum == 3) - { - // 对应dw;有一个为0或3 - if (axis0 == 0 || axis0 == 3) - { - // 对应dw - for (int c = 0; c < channels; c++) - { - int flipped_c = channels - 1 - c; // 翻转c维度 - - for (int z = 0; z < d; z++) // d维度保持不变 - { - for (int j = 0; j < h; j++) // h维度保持不变 - { - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(flipped_c).row(z * h + j); - - // 翻转w维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - else - { - // 对应ch - for (int c = 0; c < channels; c++) - { - for (int z = 0; z < d; z++) - { - int flipped_d = d - 1 - z; - - for (int j = 0; j < h; j++) - { - int flipped_h = h - 1 - j; - // 读取源数据 - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(c).row(flipped_d * h + flipped_h); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - } - else if (axis_sum == 4) - { - // 对应cw - for (int c = 0; c < channels; c++) - { - for (int z = 0; z < d; z++) - { - int flipped_d = d - 1 - z; // 翻转 d 维度 - - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(c).row(flipped_d * h + j); // c维度保持不变 - - // 翻转 w 维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - else - { - // 对应hw - for (int c = 0; c < channels; c++) - { - for (int z = 0; z < d; z++) - { - for (int j = 0; j < h; j++) - { - const float* ptr = bottom_blob.channel(c).row(z * h + j); - float* outptr = top_blob.channel(c).row(z * h + (h - 1 - j)); // 翻转 h 维度 + top_blob.create_like(bottom_blob, opt.blob_allocator); + if (top_blob.empty()) + return -100; - // 翻转 w 维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - } - else if (axis.w == 3) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int z = 0; z < d; z++) { - // dch3、dcw4、dhw5,chw6 - int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0]; - int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1]; - int axis2 = axis_ptr[2] < 0 ? 4 + axis_ptr[2] : axis_ptr[2]; - int axis_sum = axis0 + axis1 + axis2; - if (axis_sum == 3) - { - // 对应dch,除w外,其余全翻转 - for (int c = 0; c < channels; c++) - { - int flipped_c = channels - 1 - c; - - for (int z = 0; z < d; z++) - { - int flipped_d = d - 1 - z; - - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).depth(z).row(i); - float* outptr = top_blob.channel(flipped_c).depth(flipped_d).row(h - 1 - i); - memcpy(outptr, ptr, w * sizeof(float)); - } - } - } - } - else if (axis_sum == 4) - { - // 对应dcw,除h外,其余全翻转 - for (int c = 0; c < channels; c++) - { - int flipped_c = channels - 1 - c; // 翻转c维度 - - for (int z = 0; z < d; z++) - { - int flipped_d = d - 1 - z; // 翻转d维度 - - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).row(z * h + i); - float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + i); // h维度保持不变 - - // 翻转w维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - else if (axis_sum == 5) + for (int i = 0; i < h; i++) { - // 对应dhw,除了d外全翻转 - for (int c = 0; c < channels; c++) - { - int flipped_c = channels - 1 - c; // 翻转c维度 + int q2 = flip_c ? channels - 1 - q : q; + int z2 = flip_d ? d - 1 - z : z; + int i2 = flip_h ? h - 1 - i : i; - for (int z = 0; z < d; z++) // d维度保持不变 - { - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).depth(z).row(i); - float* outptr = top_blob.channel(flipped_c).depth(z).row(h - 1 - i); // 翻转h维度 + const float* ptr = bottom_blob.channel(q2).depth(z2).row(i2); + float* outptr = top_blob.channel(q).depth(z).row(i); - // 翻转w维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } - } - } - } - else if (axis_sum == 6) - { - // 对应chw,除了c外全翻转 - for (int c = 0; c < channels; c++) // c维度保持不变 + if (flip_w) { - for (int z = 0; z < d; z++) + ptr += w - 1; + for (int j = 0; j < w; j++) { - int flipped_d = d - 1 - z; // 翻转d维度 - - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).depth(z).row(i); - float* outptr = top_blob.channel(c).depth(flipped_d).row(h - 1 - i); // 翻转h维度 - // 翻转w维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } + *outptr++ = *ptr--; } } - } - } - else - { - // dchw全部翻转 - for (int c = 0; c < channels; c++) - { - int flipped_c = channels - 1 - c; // 翻转c维度 - - for (int z = 0; z < d; z++) + else { - int flipped_d = d - 1 - z; // 翻转d维度 - - for (int i = 0; i < h; i++) - { - const float* ptr = bottom_blob.channel(c).row(z * h + i); - float* outptr = top_blob.channel(flipped_c).row(flipped_d * h + (h - 1 - i)); // 翻转h维度 - - // 翻转w维度 - for (int k = 0; k < w; k++) - { - outptr[k] = ptr[w - 1 - k]; - } - } + memcpy(outptr, ptr, w * sizeof(float)); } } } } - else - { - return -1; - } return 0; } diff --git a/src/layer/flip.h b/src/layer/flip.h index 61a05d453..e675a086a 100644 --- a/src/layer/flip.h +++ b/src/layer/flip.h @@ -1,16 +1,5 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause #ifndef LAYER_FLIP_H #define LAYER_FLIP_H @@ -29,7 +18,7 @@ public: virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; public: - Mat axis; // 翻转维度 + Mat axes; }; } // namespace ncnn diff --git a/tests/test_flip.cpp b/tests/test_flip.cpp index 7ebf787a4..172f80d43 100644 --- a/tests/test_flip.cpp +++ b/tests/test_flip.cpp @@ -1,132 +1,182 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#include "layer.h" +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + #include "testutil.h" -// 为兼容低于c++11 -// ncnn::Mat axis_mat(axis.size()); -// for (size_t i = 0; i < axis.size(); i++) -// { -// axis_mat[i] = axis[i]; -// } -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static int test_flip(const ncnn::Mat& a, const ncnn::Mat& axis) +static void print_int_array(const std::vector& a) { + fprintf(stderr, "["); + for (size_t i = 0; i < a.size(); i++) + { + fprintf(stderr, " %d", a[i]); + } + fprintf(stderr, " ]"); +} + +static int test_flip(const ncnn::Mat& a, const std::vector& axes_array) +{ + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + ncnn::ParamDict pd; - pd.set(0, axis); + pd.set(0, axes); std::vector weights(0); int ret = test_layer("Flip", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d) axis_w=%d\n", a.dims, a.w, a.h, a.d, a.c, axis.w); + fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); + fprintf(stderr, " axes="); + print_int_array(axes_array); + fprintf(stderr, "\n"); } return ret; } +static int test_flip_nd(const ncnn::Mat& a) +{ + int ret1 = test_flip(a, IntArray(0)); + + if (a.dims == 1 || ret1 != 0) + return ret1; + + int ret2 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(0, 1)); + + if (a.dims == 2 || ret2 != 0) + return ret2; + + int ret3 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(2)) + || test_flip(a, IntArray(0, 1)) + || test_flip(a, IntArray(0, 2)) + || test_flip(a, IntArray(1, 2)) + || test_flip(a, IntArray(0, 1, 2)); + + if (a.dims == 3 || ret3 != 0) + return ret3; + + int ret4 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(2)) + || test_flip(a, IntArray(3)) + || test_flip(a, IntArray(0, 1)) + || test_flip(a, IntArray(0, 2)) + || test_flip(a, IntArray(0, 3)) + || test_flip(a, IntArray(1, 2)) + || test_flip(a, IntArray(1, 3)) + || test_flip(a, IntArray(2, 3)) + || test_flip(a, IntArray(0, 1, 2)) + || test_flip(a, IntArray(0, 1, 3)) + || test_flip(a, IntArray(0, 2, 3)) + || test_flip(a, IntArray(1, 2, 3)) + || test_flip(a, IntArray(0, 1, 2, 3)); + + return ret4; +} + static int test_flip_0() { + ncnn::Mat a = RandomMat(5, 6, 7, 24); + ncnn::Mat b = RandomMat(7, 8, 9, 12); + ncnn::Mat c = RandomMat(3, 4, 5, 13); + return 0 - || test_flip(RandomMat(2, 3, 4, 5), IntArrayMat(0)) - || test_flip(RandomMat(3, 2, 4, 5), IntArrayMat(1)) - || test_flip(RandomMat(4, 3, 2, 5), IntArrayMat(2)) - || test_flip(RandomMat(2, 3, 1, 5), IntArrayMat(3)) - || test_flip(RandomMat(6, 3, 4, 5), IntArrayMat(0, 1)) - || test_flip(RandomMat(2, 3, 1, 6), IntArrayMat(0, 2)) - || test_flip(RandomMat(5, 1, 2, 5), IntArrayMat(0, 3)) - || test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(1, 2)) - || test_flip(RandomMat(4, 5, 2, 3), IntArrayMat(1, 3)) - || test_flip(RandomMat(2, 6, 4, 5), IntArrayMat(2, 3)) - || test_flip(RandomMat(6, 1, 4, 5), IntArrayMat(0, 1, 2)) - || test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(0, 1, 3)) - || test_flip(RandomMat(4, 3, 3, 5), IntArrayMat(0, 2, 3)) - || test_flip(RandomMat(4, 3, 4, 5), IntArrayMat(1, 2, 3)) - || test_flip(RandomMat(6, 3, 3, 2), IntArrayMat(0, 1, 2, 3)); + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); } static int test_flip_1() { + ncnn::Mat a = RandomMat(5, 7, 24); + ncnn::Mat b = RandomMat(7, 9, 12); + ncnn::Mat c = RandomMat(3, 5, 13); + return 0 - || test_flip(RandomMat(2, 3, 5), IntArrayMat(0)) - || test_flip(RandomMat(3, 3, 5), IntArrayMat(1)) - || test_flip(RandomMat(4, 3, 5), IntArrayMat(2)) - || test_flip(RandomMat(3, 1, 5), IntArrayMat(0, 1)) - || test_flip(RandomMat(3, 2, 5), IntArrayMat(0, 2)) - || test_flip(RandomMat(3, 3, 4), IntArrayMat(1, 2)) - || test_flip(RandomMat(4, 3, 2), IntArrayMat(0, 1, 2)); + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); } static int test_flip_2() { + ncnn::Mat a = RandomMat(15, 24); + ncnn::Mat b = RandomMat(17, 12); + ncnn::Mat c = RandomMat(19, 15); + return 0 - || test_flip(RandomMat(8, 2), IntArrayMat(-2)) - || test_flip(RandomMat(16, 3), IntArrayMat(-1)) - || test_flip(RandomMat(7, 2), IntArrayMat(-2, -1)); + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); } static int test_flip_3() { + ncnn::Mat a = RandomMat(128); + ncnn::Mat b = RandomMat(124); + ncnn::Mat c = RandomMat(127); + return 0 - || test_flip(RandomMat(18), IntArrayMat(-1)); + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); } int main() { SRAND(7767517); + return 0 || test_flip_0() || test_flip_1() || test_flip_2() || test_flip_3(); -} \ No newline at end of file +} diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index f50fed155..249228dd1 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -62,14 +62,18 @@ public: : type(2) { if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; i = (int)_l; } Parameter(long long _l) : type(2) { if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; i = (int)_l; } Parameter(float _f) @@ -99,7 +103,9 @@ public: { int64_t _l = x; if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; ai.push_back((int)_l); } } @@ -114,7 +120,9 @@ public: { int64_t _l = x; if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; ai.push_back((int)_l); } } diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index c09ea6526..a5581fc12 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -76,7 +76,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) type = 2; int64_t i64 = attr.i(); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; break; } @@ -99,7 +101,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) { int64_t i64 = attr.ints().at(i); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; ai.push_back(i64); } break; @@ -165,7 +169,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) i64 = tensor.int64_data().at(0); } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; } else if (tensor.data_type() == onnx::TensorProto::FLOAT) diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 01fa3937a..2e9a2158c 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -100,7 +100,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) type = 2; int64_t i64 = value_node->i(torch::jit::attr::value); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; break; } @@ -141,7 +143,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) type = 2; int64_t i64 = t.item(); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; } else if (t.scalar_type() == c10::ScalarType::Int) @@ -193,7 +197,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) for (auto i64 : i64s) { if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; ai.push_back(i64); } break; diff --git a/tools/pnnx/src/pass_level2/torch_flip.cpp b/tools/pnnx/src/pass_level2/torch_flip.cpp index 39235654d..01ba7524b 100644 --- a/tools/pnnx/src/pass_level2/torch_flip.cpp +++ b/tools/pnnx/src/pass_level2/torch_flip.cpp @@ -27,4 +27,74 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 60) +class torch_flip_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Slice op_0 1 1 input out axes=%axes starts=%starts ends=%ends steps=%steps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.flip"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.at("axes").type == 2) + { + int axis = captured_params.at("axes").i; + int start = captured_params.at("starts").i; + int end = captured_params.at("ends").i; + int step = captured_params.at("steps").i; + + if (axis == 0 && start == -1 && end == INT_MIN + 1 && step == -1) + { + fprintf(stderr, "aaa %d %d %d\n", start, end, step); + return true; + } + } + else // if (captured_params.at("axes").type == 5) + { + const std::vector& axes = captured_params.at("axes").ai; + const std::vector& starts = captured_params.at("starts").ai; + const std::vector& ends = captured_params.at("ends").ai; + const std::vector& steps = captured_params.at("steps").ai; + + for (size_t i = 0; i < axes.size(); i++) + { + if (starts[i] != -1 || ends[i] != INT_MIN + 1 || steps[i] != -1) + { + fprintf(stderr, "%d %d %d\n", starts[i], ends[i], steps[i]); + return false; + } + } + } + + fprintf(stderr, "bbb\n"); + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("axes").type == 2) + { + int dim = captured_params.at("axes").i; + op->params["dims"] = std::vector{dim}; + } + else // if (captured_params.at("axes").type == 5) + { + op->params["dims"] = captured_params.at("axes"); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip_onnx, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_flip.cpp b/tools/pnnx/src/pass_ncnn/torch_flip.cpp index bc0e33485..179f0d292 100644 --- a/tools/pnnx/src/pass_ncnn/torch_flip.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_flip.cpp @@ -1,16 +1,6 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + #include "pass_ncnn.h" namespace pnnx { @@ -44,7 +34,6 @@ pnnx.Output output 1 0 out { const std::vector& dims = captured_params.at("dims").ai; - // 设置参数 op->params["0"] = dims; } }; @@ -53,4 +42,4 @@ REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flip, 20) } // namespace ncnn -} // namespace pnnx \ No newline at end of file +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx.cpp b/tools/pnnx/src/pass_onnx.cpp index c91f783ab..433a562fd 100644 --- a/tools/pnnx/src/pass_onnx.cpp +++ b/tools/pnnx/src/pass_onnx.cpp @@ -875,7 +875,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) i64 = tensor.int64_data().at(0); } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; op_const->params["value"] = (int)i64; } else if (tensor.data_type() == onnx::TensorProto::FLOAT) @@ -961,7 +963,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) { int64_t i64 = ai[k]; if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; expr += std::to_string(i64); if (k != (int)ai.size() - 1) expr += ","; diff --git a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp index 290def79e..9960a698b 100644 --- a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp +++ b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp @@ -146,7 +146,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model) } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; onnx::AttributeProto* attr = node->add_attribute(); attr->set_name(std::string(attr_name)); @@ -242,7 +244,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model) for (auto i64 : ai) { if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; attr->add_ints((int)i64); } diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index b39932769..afd2046f0 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -212,6 +212,7 @@ pnnx_add_test(torch_einsum) pnnx_add_test(torch_eq) pnnx_add_test(torch_diag) pnnx_add_test(torch_flatten) +pnnx_add_test(torch_flip) pnnx_add_test(torch_full) pnnx_add_test(torch_full_like) pnnx_add_test(torch_gather) diff --git a/tools/pnnx/tests/ncnn/test_torch_flip.py b/tools/pnnx/tests/ncnn/test_torch_flip.py index b07a8d297..72d016fbe 100644 --- a/tools/pnnx/tests/ncnn/test_torch_flip.py +++ b/tools/pnnx/tests/ncnn/test_torch_flip.py @@ -1,58 +1,15 @@ -# Tencent is pleased to support the open source community by making ncnn available. -# -# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software distributed -# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -# CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause import torch import torch.nn as nn import torch.nn.functional as F -# Tencent is pleased to support the open source community by making ncnn available. -# -# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software distributed -# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -# CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# Tencent is pleased to support the open source community by making ncnn available. -# -# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software distributed -# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -# CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z, d): + def forward(self, x, y, z, w): # 1D x0 = torch.flip(x, [0]) # 2D @@ -68,86 +25,53 @@ class Model(nn.Module): z5 = torch.flip(z, [1, 2]) z6 = torch.flip(z, [0, 1, 2]) # 4D - d0 = torch.flip(d, [-1]) - d1 = torch.flip(d, [-2]) - d2 = torch.flip(d, [-3]) - d3 = torch.flip(d, [-4]) - d4 = torch.flip(d, [0, 1]) - d5 = torch.flip(d, [0, 2]) - d6 = torch.flip(d, [0, 3]) - d7 = torch.flip(d, [1, 2]) - d8 = torch.flip(d, [1, 3]) - d9 = torch.flip(d, [2, 3]) - d10 = torch.flip(d, [0, 1, 2]) - d11 = torch.flip(d, [0, 1, 3]) - d12 = torch.flip(d, [0, 2, 3]) - d13 = torch.flip(d, [1, 2, 3]) - d14 = torch.flip(d, [0, 1, 2, 3]) - - return ( - x0, - y0, - y1, - y2, - z0, - z1, - z2, - z3, - z4, - z5, - z6, - d0, - d1, - d2, - d3, - d4, - d5, - d6, - d7, - d8, - d9, - d10, - d11, - d12, - d13, - d14, - ) - + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 def test(): net = Model() net.eval() torch.manual_seed(0) - x = torch.rand(36) # 1D - y = torch.rand(4, 7) # 2D - z = torch.rand(3, 4, 5) # 3D - d = torch.rand(4, 2, 6, 7) # 4D + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) - a = net(x, y, z, d) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z, d)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_torch_flip.pt") # torchscript to pnnx import os + os.system("../../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") - os.system( - "../../src/pnnx test_torch_flip.pt inputshape=[36],[4,7],[3,4,5],[4,2,6,7]" - ) - - # pnnx inference + # ncnn inference import test_torch_flip_ncnn - b = test_torch_flip_ncnn.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-3, 1e-3): + if not torch.equal(a0, b0): return False return True - if __name__ == "__main__": if test(): exit(0) diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 599cf61e4..7b1e8e099 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -157,6 +157,7 @@ pnnx_onnx_add_test(torch_ceil) pnnx_onnx_add_test(torch_chunk) pnnx_onnx_add_test(torch_clamp) pnnx_onnx_add_test(torch_flatten) +pnnx_onnx_add_test(torch_flip) pnnx_onnx_add_test(torch_floor) pnnx_onnx_add_test(torch_logical_not) pnnx_onnx_add_test(torch_logical_and) diff --git a/tools/pnnx/tests/onnx/test_torch_flip.py b/tools/pnnx/tests/onnx/test_torch_flip.py new file mode 100644 index 000000000..8fcaecd55 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_flip.py @@ -0,0 +1,78 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + # 1D + x0 = torch.flip(x, [0]) + # 2D + y0 = torch.flip(y, [0]) + y1 = torch.flip(y, [1]) + y2 = torch.flip(y, [-2, -1]) + # 3D + z0 = torch.flip(z, [0]) + z1 = torch.flip(z, [1]) + z2 = torch.flip(z, [2]) + z3 = torch.flip(z, [0, 1]) + z4 = torch.flip(z, [0, 2]) + z5 = torch.flip(z, [1, 2]) + z6 = torch.flip(z, [0, 1, 2]) + # 4D + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_flip.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_flip.onnx inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") + + # pnnx inference + import test_torch_flip_pnnx + b = test_torch_flip_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_flip.py b/tools/pnnx/tests/test_torch_flip.py new file mode 100644 index 000000000..5e5fd5e16 --- /dev/null +++ b/tools/pnnx/tests/test_torch_flip.py @@ -0,0 +1,79 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + # 1D + x0 = torch.flip(x, [0]) + # 2D + y0 = torch.flip(y, [0]) + y1 = torch.flip(y, [1]) + y2 = torch.flip(y, [-2, -1]) + # 3D + z0 = torch.flip(z, [0]) + z1 = torch.flip(z, [1]) + z2 = torch.flip(z, [2]) + z3 = torch.flip(z, [0, 1]) + z4 = torch.flip(z, [0, 2]) + z5 = torch.flip(z, [1, 2]) + z6 = torch.flip(z, [0, 1, 2]) + # 4D + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_flip.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") + + # pnnx inference + import test_torch_flip_pnnx + b = test_torch_flip_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)