From 9b91fe5153dca73ff50fcb9a2e8d0be45be4c14c Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 6 Aug 2025 14:49:16 +0800 Subject: [PATCH] implement flip layer and pnnx torch.flip conversion (#6233) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 佰阅 <43716063+Baiyuetribe@users.noreply.github.com> --- docs/developer-guide/operators.md | 9 + src/CMakeLists.txt | 1 + src/layer/flip.cpp | 117 +++++++++++ src/layer/flip.h | 26 +++ tests/CMakeLists.txt | 1 + tests/test_flip.cpp | 182 ++++++++++++++++++ tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/ir.h | 8 + tools/pnnx/src/load_onnx.cpp | 6 + tools/pnnx/src/load_torchscript.cpp | 6 + tools/pnnx/src/pass_level2/torch_flip.cpp | 64 ++++++ tools/pnnx/src/pass_ncnn/torch_flip.cpp | 58 ++++++ tools/pnnx/src/pass_onnx.cpp | 4 + .../pass_onnx/fuse_constant_as_attribute.cpp | 4 + tools/pnnx/tests/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/test_torch_flip.py | 79 ++++++++ tools/pnnx/tests/onnx/CMakeLists.txt | 1 + tools/pnnx/tests/onnx/test_torch_flip.py | 82 ++++++++ tools/pnnx/tests/test_torch_flip.py | 79 ++++++++ 20 files changed, 730 insertions(+) create mode 100644 src/layer/flip.cpp create mode 100644 src/layer/flip.h create mode 100644 tests/test_flip.cpp create mode 100644 tools/pnnx/src/pass_ncnn/torch_flip.cpp create mode 100644 tools/pnnx/tests/ncnn/test_torch_flip.py create mode 100644 tools/pnnx/tests/onnx/test_torch_flip.py create mode 100644 tools/pnnx/tests/test_torch_flip.py diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index cab7bdca1..a1c1c9cf3 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -33,6 +33,7 @@ * [Embed](#embed) * [Exp](#exp) * [Flatten](#flatten) +* [Flip](#flip) * [Fold](#fold) * [GELU](#gelu) * [GLU](#glu) @@ -870,6 +871,14 @@ Reshape blob to 1 dimension * one_blob_only +# Flip + +* one_blob_only + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | axes | array | [ ] | | + # Fold ``` y = fold(x) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cb7570a02..86b7ec319 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -170,6 +170,7 @@ ncnn_add_layer(Shrink) ncnn_add_layer(RMSNorm) ncnn_add_layer(Spectrogram) ncnn_add_layer(InverseSpectrogram) +ncnn_add_layer(Flip) if(NCNN_VULKAN) ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) diff --git a/src/layer/flip.cpp b/src/layer/flip.cpp new file mode 100644 index 000000000..01201feda --- /dev/null +++ b/src/layer/flip.cpp @@ -0,0 +1,117 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "flip.h" + +namespace ncnn { + +Flip::Flip() +{ + one_blob_only = true; +} + +int Flip::load_param(const ParamDict& pd) +{ + axes = pd.get(0, Mat()); + + if (axes.w > 4) + { + // only handle up to 4-dim + return -1; + } + + return 0; +} + +int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + if (axes.empty()) + { + top_blob = bottom_blob; + return 0; + } + + const int dims = bottom_blob.dims; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int d = bottom_blob.d; + const int channels = bottom_blob.c; + + int axes_flag[4] = {0}; + bool flip_w = false; + bool flip_h = false; + bool flip_d = false; + bool flip_c = false; + { + const int* axes_ptr = axes; + for (int i = 0; i < axes.w; i++) + { + int axis = axes_ptr[i]; + // handle negative axis + if (axis < 0) + axis += dims; + axes_flag[axis] = 1; + } + + if (dims == 1) + { + flip_w = true; + } + else if (dims == 2) + { + if (axes_flag[0] == 1) flip_h = true; + if (axes_flag[1] == 1) flip_w = true; + } + else if (dims == 3) + { + if (axes_flag[0] == 1) flip_c = true; + if (axes_flag[1] == 1) flip_h = true; + if (axes_flag[2] == 1) flip_w = true; + } + else if (dims == 4) + { + if (axes_flag[0] == 1) flip_c = true; + if (axes_flag[1] == 1) flip_d = true; + if (axes_flag[2] == 1) flip_h = true; + if (axes_flag[3] == 1) flip_w = true; + } + } + + top_blob.create_like(bottom_blob, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int z = 0; z < d; z++) + { + for (int i = 0; i < h; i++) + { + int q2 = flip_c ? channels - 1 - q : q; + int z2 = flip_d ? d - 1 - z : z; + int i2 = flip_h ? h - 1 - i : i; + + const float* ptr = bottom_blob.channel(q2).depth(z2).row(i2); + float* outptr = top_blob.channel(q).depth(z).row(i); + + if (flip_w) + { + ptr += w - 1; + for (int j = 0; j < w; j++) + { + *outptr++ = *ptr--; + } + } + else + { + memcpy(outptr, ptr, w * sizeof(float)); + } + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/flip.h b/src/layer/flip.h new file mode 100644 index 000000000..e675a086a --- /dev/null +++ b/src/layer/flip.h @@ -0,0 +1,26 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LAYER_FLIP_H +#define LAYER_FLIP_H + +#include "layer.h" + +namespace ncnn { + +class Flip : public Layer +{ +public: + Flip(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +public: + Mat axes; +}; + +} // namespace ncnn + +#endif // LAYER_FLIP_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9d5b6517e..5a0940e88 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -107,6 +107,7 @@ ncnn_add_layer_test(Embed) ncnn_add_layer_test(Erf) ncnn_add_layer_test(ExpandDims) ncnn_add_layer_test(Flatten) +ncnn_add_layer_test(Flip) ncnn_add_layer_test(Fold) ncnn_add_layer_test(GELU) ncnn_add_layer_test(GLU) diff --git a/tests/test_flip.cpp b/tests/test_flip.cpp new file mode 100644 index 000000000..172f80d43 --- /dev/null +++ b/tests/test_flip.cpp @@ -0,0 +1,182 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "testutil.h" + +static std::vector IntArray(int a0) +{ + std::vector m(1); + m[0] = a0; + return m; +} + +static std::vector IntArray(int a0, int a1) +{ + std::vector m(2); + m[0] = a0; + m[1] = a1; + return m; +} + +static std::vector IntArray(int a0, int a1, int a2) +{ + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; + return m; +} + +static std::vector IntArray(int a0, int a1, int a2, int a3) +{ + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; + return m; +} + +static void print_int_array(const std::vector& a) +{ + fprintf(stderr, "["); + for (size_t i = 0; i < a.size(); i++) + { + fprintf(stderr, " %d", a[i]); + } + fprintf(stderr, " ]"); +} + +static int test_flip(const ncnn::Mat& a, const std::vector& axes_array) +{ + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + + ncnn::ParamDict pd; + pd.set(0, axes); + + std::vector weights(0); + + int ret = test_layer("Flip", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); + fprintf(stderr, " axes="); + print_int_array(axes_array); + fprintf(stderr, "\n"); + } + + return ret; +} + +static int test_flip_nd(const ncnn::Mat& a) +{ + int ret1 = test_flip(a, IntArray(0)); + + if (a.dims == 1 || ret1 != 0) + return ret1; + + int ret2 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(0, 1)); + + if (a.dims == 2 || ret2 != 0) + return ret2; + + int ret3 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(2)) + || test_flip(a, IntArray(0, 1)) + || test_flip(a, IntArray(0, 2)) + || test_flip(a, IntArray(1, 2)) + || test_flip(a, IntArray(0, 1, 2)); + + if (a.dims == 3 || ret3 != 0) + return ret3; + + int ret4 = 0 + || test_flip(a, IntArray(0)) + || test_flip(a, IntArray(1)) + || test_flip(a, IntArray(2)) + || test_flip(a, IntArray(3)) + || test_flip(a, IntArray(0, 1)) + || test_flip(a, IntArray(0, 2)) + || test_flip(a, IntArray(0, 3)) + || test_flip(a, IntArray(1, 2)) + || test_flip(a, IntArray(1, 3)) + || test_flip(a, IntArray(2, 3)) + || test_flip(a, IntArray(0, 1, 2)) + || test_flip(a, IntArray(0, 1, 3)) + || test_flip(a, IntArray(0, 2, 3)) + || test_flip(a, IntArray(1, 2, 3)) + || test_flip(a, IntArray(0, 1, 2, 3)); + + return ret4; +} + +static int test_flip_0() +{ + ncnn::Mat a = RandomMat(5, 6, 7, 24); + ncnn::Mat b = RandomMat(7, 8, 9, 12); + ncnn::Mat c = RandomMat(3, 4, 5, 13); + + return 0 + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); +} + +static int test_flip_1() +{ + ncnn::Mat a = RandomMat(5, 7, 24); + ncnn::Mat b = RandomMat(7, 9, 12); + ncnn::Mat c = RandomMat(3, 5, 13); + + return 0 + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); +} + +static int test_flip_2() +{ + ncnn::Mat a = RandomMat(15, 24); + ncnn::Mat b = RandomMat(17, 12); + ncnn::Mat c = RandomMat(19, 15); + + return 0 + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); +} + +static int test_flip_3() +{ + ncnn::Mat a = RandomMat(128); + ncnn::Mat b = RandomMat(124); + ncnn::Mat c = RandomMat(127); + + return 0 + || test_flip_nd(a) + || test_flip_nd(b) + || test_flip_nd(c); +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_flip_0() + || test_flip_1() + || test_flip_2() + || test_flip_3(); +} diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index d6b967848..cdfa51074 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -592,6 +592,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_cumsum.cpp pass_ncnn/torch_diag.cpp pass_ncnn/torch_flatten.cpp + pass_ncnn/torch_flip.cpp pass_ncnn/torch_istft.cpp pass_ncnn/torch_logsumexp.cpp pass_ncnn/torch_matmul.cpp diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index f50fed155..249228dd1 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -62,14 +62,18 @@ public: : type(2) { if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; i = (int)_l; } Parameter(long long _l) : type(2) { if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; i = (int)_l; } Parameter(float _f) @@ -99,7 +103,9 @@ public: { int64_t _l = x; if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; ai.push_back((int)_l); } } @@ -114,7 +120,9 @@ public: { int64_t _l = x; if (_l == std::numeric_limits::max()) _l = INT_MAX; + if (_l == std::numeric_limits::max() - 1) _l = INT_MAX - 1; if (_l == std::numeric_limits::min()) _l = INT_MIN; + if (_l == std::numeric_limits::min() + 1) _l = INT_MIN + 1; ai.push_back((int)_l); } } diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index c09ea6526..a5581fc12 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -76,7 +76,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) type = 2; int64_t i64 = attr.i(); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; break; } @@ -99,7 +101,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) { int64_t i64 = attr.ints().at(i); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; ai.push_back(i64); } break; @@ -165,7 +169,9 @@ Parameter::Parameter(const onnx::AttributeProto& attr) i64 = tensor.int64_data().at(0); } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; } else if (tensor.data_type() == onnx::TensorProto::FLOAT) diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 01fa3937a..2e9a2158c 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -100,7 +100,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) type = 2; int64_t i64 = value_node->i(torch::jit::attr::value); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; break; } @@ -141,7 +143,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) type = 2; int64_t i64 = t.item(); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; i = (int)i64; } else if (t.scalar_type() == c10::ScalarType::Int) @@ -193,7 +197,9 @@ Parameter::Parameter(const torch::jit::Node* value_node) for (auto i64 : i64s) { if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; ai.push_back(i64); } break; diff --git a/tools/pnnx/src/pass_level2/torch_flip.cpp b/tools/pnnx/src/pass_level2/torch_flip.cpp index 39235654d..fc78aecc1 100644 --- a/tools/pnnx/src/pass_level2/torch_flip.cpp +++ b/tools/pnnx/src/pass_level2/torch_flip.cpp @@ -27,4 +27,68 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 60) +class torch_flip_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Slice op_0 1 1 input out axes=%axes starts=%starts ends=%ends steps=%steps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.flip"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.at("axes").type == 2) + { + int start = captured_params.at("starts").i; + int end = captured_params.at("ends").i; + int step = captured_params.at("steps").i; + + if (start == -1 && end == INT_MIN + 1 && step == -1) + return true; + } + else // if (captured_params.at("axes").type == 5) + { + const std::vector& axes = captured_params.at("axes").ai; + const std::vector& starts = captured_params.at("starts").ai; + const std::vector& ends = captured_params.at("ends").ai; + const std::vector& steps = captured_params.at("steps").ai; + + for (size_t i = 0; i < axes.size(); i++) + { + if (starts[i] != -1 || ends[i] != INT_MIN + 1 || steps[i] != -1) + return false; + } + + return true; + } + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("axes").type == 2) + { + int dim = captured_params.at("axes").i; + op->params["dims"] = std::vector{dim}; + } + else // if (captured_params.at("axes").type == 5) + { + op->params["dims"] = captured_params.at("axes"); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip_onnx, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_flip.cpp b/tools/pnnx/src/pass_ncnn/torch_flip.cpp new file mode 100644 index 000000000..503ce0446 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_flip.cpp @@ -0,0 +1,58 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_flip : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.flip op_0 1 1 input out dims=%dims +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Flip"; + } + + const char* name_str() const + { + return "flip"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dims").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = new_dims; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flip, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx.cpp b/tools/pnnx/src/pass_onnx.cpp index c91f783ab..433a562fd 100644 --- a/tools/pnnx/src/pass_onnx.cpp +++ b/tools/pnnx/src/pass_onnx.cpp @@ -875,7 +875,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) i64 = tensor.int64_data().at(0); } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; op_const->params["value"] = (int)i64; } else if (tensor.data_type() == onnx::TensorProto::FLOAT) @@ -961,7 +963,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph) { int64_t i64 = ai[k]; if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; expr += std::to_string(i64); if (k != (int)ai.size() - 1) expr += ","; diff --git a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp index 290def79e..9960a698b 100644 --- a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp +++ b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp @@ -146,7 +146,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model) } if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; onnx::AttributeProto* attr = node->add_attribute(); attr->set_name(std::string(attr_name)); @@ -242,7 +244,9 @@ void fuse_constant_as_attribute(onnx::ModelProto& model) for (auto i64 : ai) { if (i64 == std::numeric_limits::max()) i64 = INT_MAX; + if (i64 == std::numeric_limits::max() - 1) i64 = INT_MAX - 1; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; + if (i64 == std::numeric_limits::min() + 1) i64 = INT_MIN + 1; attr->add_ints((int)i64); } diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index b39932769..afd2046f0 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -212,6 +212,7 @@ pnnx_add_test(torch_einsum) pnnx_add_test(torch_eq) pnnx_add_test(torch_diag) pnnx_add_test(torch_flatten) +pnnx_add_test(torch_flip) pnnx_add_test(torch_full) pnnx_add_test(torch_full_like) pnnx_add_test(torch_gather) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index b3838cf1c..7acab3c3f 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -189,6 +189,7 @@ pnnx_ncnn_add_test(torch_clamp) pnnx_ncnn_add_test(torch_cos) pnnx_ncnn_add_test(torch_exp) pnnx_ncnn_add_test(torch_floor) +pnnx_ncnn_add_test(torch_flip) pnnx_ncnn_add_test(torch_log) pnnx_ncnn_add_test(torch_log10) pnnx_ncnn_add_test(torch_maximum) diff --git a/tools/pnnx/tests/ncnn/test_torch_flip.py b/tools/pnnx/tests/ncnn/test_torch_flip.py new file mode 100644 index 000000000..72d016fbe --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_flip.py @@ -0,0 +1,79 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + # 1D + x0 = torch.flip(x, [0]) + # 2D + y0 = torch.flip(y, [0]) + y1 = torch.flip(y, [1]) + y2 = torch.flip(y, [-2, -1]) + # 3D + z0 = torch.flip(z, [0]) + z1 = torch.flip(z, [1]) + z2 = torch.flip(z, [2]) + z3 = torch.flip(z, [0, 1]) + z4 = torch.flip(z, [0, 2]) + z5 = torch.flip(z, [1, 2]) + z6 = torch.flip(z, [0, 1, 2]) + # 4D + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_flip.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") + + # ncnn inference + import test_torch_flip_ncnn + b = test_torch_flip_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 599cf61e4..7b1e8e099 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -157,6 +157,7 @@ pnnx_onnx_add_test(torch_ceil) pnnx_onnx_add_test(torch_chunk) pnnx_onnx_add_test(torch_clamp) pnnx_onnx_add_test(torch_flatten) +pnnx_onnx_add_test(torch_flip) pnnx_onnx_add_test(torch_floor) pnnx_onnx_add_test(torch_logical_not) pnnx_onnx_add_test(torch_logical_and) diff --git a/tools/pnnx/tests/onnx/test_torch_flip.py b/tools/pnnx/tests/onnx/test_torch_flip.py new file mode 100644 index 000000000..15b20bb77 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_flip.py @@ -0,0 +1,82 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + # 1D + x0 = torch.flip(x, [0]) + # 2D + y0 = torch.flip(y, [0]) + y1 = torch.flip(y, [1]) + y2 = torch.flip(y, [-2, -1]) + # 3D + z0 = torch.flip(z, [0]) + z1 = torch.flip(z, [1]) + z2 = torch.flip(z, [2]) + z3 = torch.flip(z, [0, 1]) + z4 = torch.flip(z, [0, 2]) + z5 = torch.flip(z, [1, 2]) + z6 = torch.flip(z, [0, 1, 2]) + # 4D + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 + +def test(): + if version.parse(torch.__version__) < version.parse('1.12'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_flip.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_flip.onnx inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") + + # pnnx inference + import test_torch_flip_pnnx + b = test_torch_flip_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_flip.py b/tools/pnnx/tests/test_torch_flip.py new file mode 100644 index 000000000..5e5fd5e16 --- /dev/null +++ b/tools/pnnx/tests/test_torch_flip.py @@ -0,0 +1,79 @@ +# Copyright 2025 Tencent +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + # 1D + x0 = torch.flip(x, [0]) + # 2D + y0 = torch.flip(y, [0]) + y1 = torch.flip(y, [1]) + y2 = torch.flip(y, [-2, -1]) + # 3D + z0 = torch.flip(z, [0]) + z1 = torch.flip(z, [1]) + z2 = torch.flip(z, [2]) + z3 = torch.flip(z, [0, 1]) + z4 = torch.flip(z, [0, 2]) + z5 = torch.flip(z, [1, 2]) + z6 = torch.flip(z, [0, 1, 2]) + # 4D + w0 = torch.flip(w, [-1]) + w1 = torch.flip(w, [-2]) + w2 = torch.flip(w, [-3]) + w3 = torch.flip(w, [-4]) + w4 = torch.flip(w, [0, 1]) + w5 = torch.flip(w, [0, 2]) + w6 = torch.flip(w, [0, 3]) + w7 = torch.flip(w, [1, 2]) + w8 = torch.flip(w, [1, 3]) + w9 = torch.flip(w, [2, 3]) + w10 = torch.flip(w, [0, 1, 2]) + w11 = torch.flip(w, [0, 1, 3]) + w12 = torch.flip(w, [0, 2, 3]) + w13 = torch.flip(w, [1, 2, 3]) + w14 = torch.flip(w, [0, 1, 2, 3]) + + return x0, y0, y1, y2, z0, z1, z2, z3, z4, z5, z6, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(36) + y = torch.rand(14, 17) + z = torch.rand(13, 14, 15) + w = torch.rand(48, 12, 16, 17) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_flip.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_flip.pt inputshape=[36],[14,17],[13,14,15],[48,12,16,17]") + + # pnnx inference + import test_torch_flip_pnnx + b = test_torch_flip_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)