diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7cde1e194..7bdad85ef 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -149,6 +149,7 @@ ncnn_add_layer(ConvolutionDepthWise1D) ncnn_add_layer(Convolution3D) ncnn_add_layer(ConvolutionDepthWise3D) ncnn_add_layer(Pooling3D) +ncnn_add_layer(MatMul) if(NCNN_VULKAN) ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) diff --git a/src/layer/matmul.cpp b/src/layer/matmul.cpp new file mode 100644 index 000000000..211f0ff30 --- /dev/null +++ b/src/layer/matmul.cpp @@ -0,0 +1,401 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "matmul.h" + +namespace ncnn { + +MatMul::MatMul() +{ + one_blob_only = false; + support_inplace = false; +} + +int MatMul::load_param(const ParamDict& pd) +{ + transB = pd.get(0, 0); + + return 0; +} + +static void transpose(const Mat& X, Mat& XT, const Option& opt) +{ + const int w = X.w; + const int h = X.h; + + const float* pX = X; + float* pXT = XT; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + float* ptr = pXT + i * h; + for (int j = 0; j < h; j++) + { + ptr[j] = pX[j * w + i]; + } + } +} + +static void matmul_transb(const Mat& A, const Mat& B, Mat& top_blob, const Option& opt) +{ + const int M = A.h; + const int K = A.w; // assert A.w == B.w + const int N = B.h; + + const float* pA = A; + const float* pB = B; + float* pOut = top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < M; i++) + { + const float* ptrA = pA + i * K; + float* outptr = pOut + i * N; + + for (int j = 0; j < N; j++) + { + const float* ptrB = pB + j * K; + + float sum = 0.f; + for (int k = 0; k < K; k++) + { + sum += ptrA[k] * ptrB[k]; + } + + *outptr++ = sum; + } + } +} + +int MatMul::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + Mat& top_blob = top_blobs[0]; + + const int Adims = A.dims; + const int Bdims = B.dims; + const int max_ABdims = std::max(Adims, Bdims); + const size_t elemsize = A.elemsize; + + if (Adims == 1 && Bdims == 1) + { + // dot product + top_blob.create(1, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + const int K = A.w; // assert A.w == B.w + const float* ptrA = A; + const float* ptrB = B; + + float sum = 0.f; + for (int k = 0; k < K; k++) + { + sum += ptrA[k] * ptrB[k]; + } + + top_blob[0] = sum; + } + else if (Adims == 2 && Bdims == 2) + { + // matrix multiply + const int M = A.h; + const int N = transB == 0 ? B.w : B.h; + + top_blob.create(N, M, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + Mat BT; + if (transB == 0) + { + BT.create(B.h, B.w, elemsize, opt.workspace_allocator); + if (BT.empty()) + return -100; + + transpose(B, BT, opt); + } + else + { + BT = B; + } + + matmul_transb(A, BT, top_blob, opt); + } + else if (Adims == 1 && Bdims == 2) + { + // matrix multiply + const int N = transB == 0 ? B.w : B.h; + + Mat top_blob1(N, 1, elemsize, opt.blob_allocator); + if (top_blob1.empty()) + return -100; + + Mat A1 = A.reshape(A.w, 1); + + Mat BT; + if (transB == 0) + { + BT.create(B.h, B.w, elemsize, opt.workspace_allocator); + if (BT.empty()) + return -100; + + transpose(B, BT, opt); + } + else + { + BT = B; + } + + matmul_transb(A1, BT, top_blob1, opt); + + top_blob = top_blob1.reshape(N); + } + else if (Adims == 2 && Bdims == 1) + { + // matrix multiply + const int M = A.h; + + Mat top_blob1(1, M, elemsize, opt.blob_allocator); + if (top_blob1.empty()) + return -100; + + Mat BT = B.reshape(B.w, 1); + + matmul_transb(A, BT, top_blob1, opt); + + top_blob = top_blob1.reshape(M); + } + else if (Adims == 1 && Bdims > 2) + { + // batched matrix multiply + const int N = transB == 0 ? B.w : B.h; + const int batch_size = B.d * B.c; + + Mat top_blob1(N, 1, batch_size, elemsize, opt.blob_allocator); + if (top_blob1.empty()) + return -100; + + Mat A1 = A.reshape(A.w, 1); + Mat B1 = B.reshape(B.w, B.h, batch_size); + + for (int p = 0; p < batch_size; p++) + { + Mat BT; + if (transB == 0) + { + BT.create(B.h, B.w, elemsize, opt.workspace_allocator); + if (BT.empty()) + return -100; + + transpose(B1.channel(p), BT, opt); + } + else + { + BT = B1.channel(p); + } + + Mat top_blob1_p = top_blob1.channel(p); + matmul_transb(A1, BT, top_blob1_p, opt); + } + + if (Bdims == 3) + top_blob = top_blob1.reshape(N, B.d * B.c); + else + top_blob = top_blob1.reshape(N, B.d, B.c); + } + else if (Adims > 2 && Bdims == 1) + { + // batched matrix multiply + const int M = A.h; + const int batch_size = A.d * A.c; + + Mat top_blob1(1, M, batch_size, elemsize, opt.blob_allocator); + if (top_blob1.empty()) + return -100; + + Mat A1 = A.reshape(A.w, A.h, batch_size); + Mat BT = B.reshape(B.w, 1); + + for (int p = 0; p < batch_size; p++) + { + Mat top_blob1_p = top_blob1.channel(p); + matmul_transb(A1.channel(p), BT, top_blob1_p, opt); + } + + if (Adims == 3) + top_blob = top_blob1.reshape(M, A.d * A.c); + else + top_blob = top_blob1.reshape(M, A.d, A.c); + } + else if (max_ABdims == 3) + { + Mat A1 = Adims == 2 ? A.reshape(A.w, A.h, 1) : A; + Mat B1 = Bdims == 2 ? B.reshape(B.w, B.h, 1) : B; + + const int M = A1.h; + const int N = transB == 0 ? B1.w : B1.h; + const int batch_size = std::max(A1.c, B1.c); + + top_blob.create(N, M, batch_size, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + Mat BT0; + if (B1.c == 1) + { + if (transB == 0) + { + BT0.create(B1.h, B1.w, elemsize, opt.workspace_allocator); + if (BT0.empty()) + return -100; + + transpose(B1.channel(0), BT0, opt); + } + else + { + BT0 = B1.channel(0); + } + } + + for (int p = 0; p < batch_size; p++) + { + int Ap = A1.c == 1 ? 0 : p; + int Bp = B1.c == 1 ? 0 : p; + + Mat BT; + if (B1.c == 1) + { + BT = BT0; + } + else + { + if (transB == 0) + { + BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator); + if (BT.empty()) + return -100; + + transpose(B1.channel(Bp), BT, opt); + } + else + { + BT = B1.channel(Bp); + } + } + + Mat top_blob_p = top_blob.channel(p); + matmul_transb(A1.channel(Ap), BT, top_blob_p, opt); + } + } + else if (max_ABdims == 4) + { + Mat A1 = Adims == 3 ? A.reshape(A.w, A.h, A.c, 1) : A; + Mat B1 = Bdims == 3 ? B.reshape(B.w, B.h, B.c, 1) : B; + + const int M = A1.h; + const int N = transB == 0 ? B1.w : B1.h; + const int batch_size_d = std::max(A1.d, B1.d); + const int batch_size_c = std::max(A1.c, B1.c); + + top_blob.create(N, M, batch_size_d, batch_size_c, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + Mat BT00; + if (B1.d == 1 && B1.c == 1) + { + if (transB == 0) + { + BT00.create(B1.h, B1.w, elemsize, opt.workspace_allocator); + if (BT00.empty()) + return -100; + + transpose(B1.channel(0).depth(0), BT00, opt); + } + else + { + BT00 = B1.channel(0).depth(0); + } + } + + for (int p = 0; p < batch_size_c; p++) + { + int Ap = A1.c == 1 ? 0 : p; + int Bp = B1.c == 1 ? 0 : p; + + Mat BT0x; + if (B1.d == 1 && B1.c != 1) + { + if (transB == 0) + { + BT0x.create(B1.h, B1.w, elemsize, opt.workspace_allocator); + if (BT0x.empty()) + return -100; + + transpose(B1.channel(Bp).depth(0), BT0x, opt); + } + else + { + BT0x = B1.channel(Bp).depth(0); + } + } + + for (int q = 0; q < batch_size_d; q++) + { + int Ad = A1.d == 1 ? 0 : q; + int Bd = B1.d == 1 ? 0 : q; + + Mat BT; + if (B1.d == 1 && B1.c == 1) + { + BT = BT00; + } + else if (B1.d == 1 && B1.c != 1) + { + BT = BT0x; + } + else + { + if (transB == 0) + { + BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator); + if (BT.empty()) + return -100; + + transpose(B1.channel(Bp).depth(Bd), BT, opt); + } + else + { + BT = B1.channel(Bp).depth(Bd); + } + } + + Mat top_blob_p_q = top_blob.channel(p).depth(q); + matmul_transb(A1.channel(Ap).depth(Ad), BT, top_blob_p_q, opt); + } + } + } + else + { + NCNN_LOGE("impossible matmul %d %d", Adims, Bdims); + return -1; + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/matmul.h b/src/layer/matmul.h new file mode 100644 index 000000000..153670961 --- /dev/null +++ b/src/layer/matmul.h @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_MATMUL_H +#define LAYER_MATMUL_H + +#include "layer.h" + +namespace ncnn { + +class MatMul : public Layer +{ +public: + MatMul(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +public: + int transB; +}; + +} // namespace ncnn + +#endif // LAYER_MATMUL_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8b4ea7b2d..736fb9934 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -90,6 +90,7 @@ ncnn_add_layer_test(Interp) ncnn_add_layer_test(LayerNorm) ncnn_add_layer_test(LRN) ncnn_add_layer_test(LSTM) +ncnn_add_layer_test(MatMul) ncnn_add_layer_test(MemoryData) ncnn_add_layer_test(Mish) ncnn_add_layer_test(MultiHeadAttention) diff --git a/tests/test_matmul.cpp b/tests/test_matmul.cpp new file mode 100644 index 000000000..34b4aad03 --- /dev/null +++ b/tests/test_matmul.cpp @@ -0,0 +1,255 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "layer/matmul.h" +#include "testutil.h" + +static int test_matmul(const ncnn::Mat& a, const ncnn::Mat& b) +{ + ncnn::ParamDict pd; + pd.set(0, 0); // transB + + std::vector weights(0); + + std::vector as(2); + as[0] = a; + as[1] = b; + + int ret = test_layer("MatMul", pd, weights, as); + if (ret != 0) + { + fprintf(stderr, "test_matmul failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c); + } + + return ret; +} + +static int test_matmul_transb(const ncnn::Mat& a, const ncnn::Mat& b) +{ + ncnn::ParamDict pd; + pd.set(0, 1); // transB + + std::vector weights(0); + + std::vector as(2); + as[0] = a; + as[1] = b; + + int ret = test_layer("MatMul", pd, weights, as); + if (ret != 0) + { + fprintf(stderr, "test_matmul_transb failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c); + } + + return ret; +} + +static int test_matmul_0() +{ + return 0 + || test_matmul(RandomMat(124), RandomMat(124)) + || test_matmul(RandomMat(127), RandomMat(127)) + || test_matmul(RandomMat(128), RandomMat(128)); +} + +static int test_matmul_1() +{ + return 0 + || test_matmul(RandomMat(5), RandomMat(6, 5)) + || test_matmul(RandomMat(16), RandomMat(12, 16)) + || test_matmul(RandomMat(11), RandomMat(16, 11)) + + || test_matmul_transb(RandomMat(5), RandomMat(5, 6)) + || test_matmul_transb(RandomMat(16), RandomMat(16, 12)) + || test_matmul_transb(RandomMat(11), RandomMat(11, 16)); +} + +static int test_matmul_2() +{ + return 0 + || test_matmul(RandomMat(13), RandomMat(7, 13, 12)) + || test_matmul(RandomMat(24), RandomMat(6, 24, 16)) + || test_matmul(RandomMat(20), RandomMat(8, 20, 19)) + + || test_matmul_transb(RandomMat(13), RandomMat(13, 7, 12)) + || test_matmul_transb(RandomMat(24), RandomMat(24, 6, 16)) + || test_matmul_transb(RandomMat(20), RandomMat(20, 8, 19)); +} + +static int test_matmul_3() +{ + return 0 + || test_matmul(RandomMat(13), RandomMat(7, 13, 5, 12)) + || test_matmul(RandomMat(24), RandomMat(6, 24, 4, 16)) + || test_matmul(RandomMat(20), RandomMat(8, 20, 3, 19)) + + || test_matmul_transb(RandomMat(13), RandomMat(13, 7, 5, 12)) + || test_matmul_transb(RandomMat(24), RandomMat(24, 6, 4, 16)) + || test_matmul_transb(RandomMat(20), RandomMat(20, 8, 3, 19)); +} + +static int test_matmul_4() +{ + return 0 + || test_matmul(RandomMat(5, 6), RandomMat(5)) + || test_matmul(RandomMat(16, 12), RandomMat(16)) + || test_matmul(RandomMat(11, 16), RandomMat(11)); +} + +static int test_matmul_5() +{ + return 0 + || test_matmul(RandomMat(32, 3, 10), RandomMat(32)) + || test_matmul(RandomMat(31, 4, 16), RandomMat(31)) + || test_matmul(RandomMat(28, 5, 28), RandomMat(28)); +} + +static int test_matmul_6() +{ + return 0 + || test_matmul(RandomMat(32, 3, 4, 10), RandomMat(32)) + || test_matmul(RandomMat(31, 4, 5, 16), RandomMat(31)) + || test_matmul(RandomMat(18, 5, 6, 28), RandomMat(18)); +} + +static int test_matmul_7() +{ + return 0 + || test_matmul(RandomMat(14, 10), RandomMat(5, 14)) + || test_matmul(RandomMat(16, 16), RandomMat(10, 16)) + || test_matmul(RandomMat(14, 28), RandomMat(9, 14)) + + || test_matmul_transb(RandomMat(14, 10), RandomMat(14, 5)) + || test_matmul_transb(RandomMat(16, 16), RandomMat(16, 10)) + || test_matmul_transb(RandomMat(14, 28), RandomMat(14, 9)); +} + +static int test_matmul_8() +{ + return 0 + || test_matmul(RandomMat(5, 4), RandomMat(4, 5, 12)) + || test_matmul(RandomMat(5, 14), RandomMat(5, 5, 16)) + || test_matmul(RandomMat(5, 24), RandomMat(6, 5, 19)) + + || test_matmul_transb(RandomMat(5, 4), RandomMat(5, 4, 12)) + || test_matmul_transb(RandomMat(5, 14), RandomMat(5, 5, 16)) + || test_matmul_transb(RandomMat(5, 24), RandomMat(5, 6, 19)); +} + +static int test_matmul_9() +{ + return 0 + || test_matmul(RandomMat(5, 4), RandomMat(4, 5, 2, 12)) + || test_matmul(RandomMat(5, 14), RandomMat(5, 5, 3, 16)) + || test_matmul(RandomMat(5, 24), RandomMat(6, 5, 4, 19)) + + || test_matmul_transb(RandomMat(5, 4), RandomMat(5, 4, 2, 12)) + || test_matmul_transb(RandomMat(5, 14), RandomMat(5, 5, 3, 16)) + || test_matmul_transb(RandomMat(5, 24), RandomMat(5, 6, 4, 19)); +} + +static int test_matmul_10() +{ + return 0 + || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14)) + || test_matmul(RandomMat(16, 22, 16), RandomMat(10, 16)) + || test_matmul(RandomMat(14, 20, 28), RandomMat(9, 14)) + + || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5)) + || test_matmul_transb(RandomMat(16, 22, 16), RandomMat(16, 10)) + || test_matmul_transb(RandomMat(14, 20, 28), RandomMat(14, 9)); +} + +static int test_matmul_11() +{ + return 0 + || test_matmul(RandomMat(14, 13, 2, 10), RandomMat(5, 14)) + || test_matmul(RandomMat(16, 12, 3, 16), RandomMat(10, 16)) + || test_matmul(RandomMat(14, 10, 4, 28), RandomMat(9, 14)) + + || test_matmul_transb(RandomMat(14, 13, 2, 10), RandomMat(14, 5)) + || test_matmul_transb(RandomMat(16, 12, 3, 16), RandomMat(16, 10)) + || test_matmul_transb(RandomMat(14, 10, 4, 28), RandomMat(14, 9)); +} + +static int test_matmul_12() +{ + return 0 + || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14, 10)) + || test_matmul(RandomMat(16, 22, 16), RandomMat(10, 16, 16)) + || test_matmul(RandomMat(14, 20, 28), RandomMat(9, 14, 28)) + + || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5, 10)) + || test_matmul_transb(RandomMat(16, 22, 16), RandomMat(16, 10, 16)) + || test_matmul_transb(RandomMat(14, 20, 28), RandomMat(14, 9, 28)); +} + +static int test_matmul_13() +{ + return 0 + || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14, 1, 16)) + || test_matmul(RandomMat(16, 22, 9), RandomMat(10, 16, 1, 17)) + || test_matmul(RandomMat(14, 20, 8), RandomMat(9, 14, 1, 18)) + + || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5, 1, 16)) + || test_matmul_transb(RandomMat(16, 22, 9), RandomMat(16, 10, 1, 17)) + || test_matmul_transb(RandomMat(14, 20, 8), RandomMat(14, 9, 1, 18)); +} + +static int test_matmul_14() +{ + return 0 + || test_matmul(RandomMat(14, 23, 10, 1), RandomMat(5, 14, 1, 16)) + || test_matmul(RandomMat(16, 22, 9, 1), RandomMat(10, 16, 1, 17)) + || test_matmul(RandomMat(14, 20, 8, 1), RandomMat(9, 14, 1, 18)) + + || test_matmul_transb(RandomMat(14, 23, 10, 1), RandomMat(14, 5, 1, 16)) + || test_matmul_transb(RandomMat(16, 22, 9, 1), RandomMat(16, 10, 1, 17)) + || test_matmul_transb(RandomMat(14, 20, 8, 1), RandomMat(14, 9, 1, 18)); +} + +static int test_matmul_15() +{ + return 0 + || test_matmul(RandomMat(14, 23, 10, 16), RandomMat(5, 14, 10, 16)) + || test_matmul(RandomMat(16, 22, 9, 17), RandomMat(10, 16, 9, 17)) + || test_matmul(RandomMat(14, 20, 8, 18), RandomMat(9, 14, 8, 18)) + + || test_matmul_transb(RandomMat(14, 23, 10, 16), RandomMat(14, 5, 10, 16)) + || test_matmul_transb(RandomMat(16, 22, 9, 17), RandomMat(16, 10, 9, 17)) + || test_matmul_transb(RandomMat(14, 20, 8, 18), RandomMat(14, 9, 8, 18)); +} + +int main() +{ + SRAND(7767517); + + return 0 + || test_matmul_0() + || test_matmul_1() + || test_matmul_2() + || test_matmul_3() + || test_matmul_4() + || test_matmul_5() + || test_matmul_6() + || test_matmul_7() + || test_matmul_8() + || test_matmul_9() + || test_matmul_10() + || test_matmul_11() + || test_matmul_12() + || test_matmul_13() + || test_matmul_14() + || test_matmul_15(); +} diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 5b8cecfad..023eafdde 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -234,7 +234,9 @@ set(pnnx_pass_level4_SRCS set(pnnx_pass_level5_SRCS pass_level5/eliminate_dropout.cpp + pass_level5/eliminate_identity_operator.cpp pass_level5/eliminate_maxpool_indices.cpp + pass_level5/eliminate_noop_pad.cpp pass_level5/eliminate_slice.cpp pass_level5/eliminate_view_reshape.cpp pass_level5/eval_expression.cpp @@ -247,6 +249,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_convtranspose2d_batchnorm2d.cpp pass_level5/fuse_contiguous_view.cpp pass_level5/fuse_linear_batchnorm1d.cpp + pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_slice_indices.cpp pass_level5/unroll_rnn_op.cpp ) @@ -274,6 +277,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/fuse_deconvolution_activation.cpp pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp pass_ncnn/fuse_innerproduct_activation.cpp + pass_ncnn/fuse_transpose_matmul.cpp pass_ncnn/F_adaptive_avg_pool1d.cpp pass_ncnn/F_adaptive_avg_pool2d.cpp @@ -390,6 +394,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_clone.cpp pass_ncnn/torch_flatten.cpp pass_ncnn/torch_logsumexp.cpp + pass_ncnn/torch_matmul.cpp pass_ncnn/torch_mean.cpp pass_ncnn/torch_permute.cpp pass_ncnn/torch_prod.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index d776f73c9..880f59032 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -268,6 +268,38 @@ Parameter::Parameter(const torch::jit::Value* value) { } +bool operator==(const Parameter& lhs, const Parameter& rhs) +{ + if (lhs.type != rhs.type) + return false; + + if (lhs.type == 0) + return true; + + if (lhs.type == 1 && lhs.b == rhs.b) + return true; + + if (lhs.type == 2 && lhs.i == rhs.i) + return true; + + if (lhs.type == 3 && lhs.f == rhs.f) + return true; + + if (lhs.type == 4 && lhs.s == rhs.s) + return true; + + if (lhs.type == 5 && lhs.ai == rhs.ai) + return true; + + if (lhs.type == 6 && lhs.af == rhs.af) + return true; + + if (lhs.type == 7 && lhs.as == rhs.as) + return true; + + return false; +} + Attribute::Attribute(const at::Tensor& t) { type = get_at_tensor_type(t.scalar_type()); @@ -343,6 +375,23 @@ Attribute::Attribute(const std::initializer_list& _shape, const std::vector } } +bool operator==(const Attribute& lhs, const Attribute& rhs) +{ + if (lhs.type != rhs.type) + return false; + + if (lhs.type == 0) + return true; + + if (lhs.shape != rhs.shape) + return false; + + if (lhs.data != rhs.data) + return false; + + return true; +} + Parameter Parameter::parse_from_string(const std::string& value) { Parameter p; diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index ecc2b0cd1..07e5259dc 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -132,6 +132,8 @@ public: std::vector as; }; +bool operator==(const Parameter& lhs, const Parameter& rhs); + class Attribute { public: @@ -151,6 +153,8 @@ public: std::vector data; }; +bool operator==(const Attribute& lhs, const Attribute& rhs); + class Operator; class Operand { diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 757a8f180..b62368af6 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -16,6 +16,8 @@ #include "pass_level5/fold_constants.h" #include "pass_level5/eliminate_dropout.h" +#include "pass_level5/eliminate_identity_operator.h" +#include "pass_level5/eliminate_noop_pad.h" #include "pass_level5/eliminate_slice.h" #include "pass_level5/eliminate_view_reshape.h" #include "pass_level5/eval_expression.h" @@ -27,6 +29,7 @@ #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" #include "pass_level5/fuse_contiguous_view.h" #include "pass_level5/fuse_linear_batchnorm1d.h" +#include "pass_level5/fuse_select_to_unbind.h" #include "pass_level5/fuse_slice_indices.h" #include "pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" @@ -44,6 +47,10 @@ void pass_level5(Graph& g, const std::map& foldable_cons fuse_slice_indices(g); + eliminate_identity_operator(g); + + fuse_select_to_unbind(g); + fuse_conv1d_batchnorm1d(g); fuse_conv2d_batchnorm2d(g); @@ -54,12 +61,14 @@ void pass_level5(Graph& g, const std::map& foldable_cons fuse_linear_batchnorm1d(g); + eliminate_noop_pad(g); + + eliminate_dropout(g); + fuse_contiguous_view(g); eliminate_view_reshape(g); - eliminate_dropout(g); - fuse_channel_shuffle(g); fold_constants(g, foldable_constants); diff --git a/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp b/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp new file mode 100644 index 000000000..b13068fe7 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp @@ -0,0 +1,118 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_identity_operator.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_identity_operator(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op0 = graph.ops[i]; + + if (op0->type == "pnnx.Input" || op0->type == "pnnx.Output") + continue; + + Operator* op1 = 0; + + for (size_t j = i + 1; j < graph.ops.size(); j++) + { + op1 = graph.ops[j]; + + if (op1->type == "pnnx.Input" || op1->type == "pnnx.Output") + continue; + + if (op0->type != op1->type) + continue; + + if (op0->inputs != op1->inputs) + continue; + + if (op0->outputs.size() != op1->outputs.size()) + continue; + + if (op0->params != op1->params) + continue; + + if (op0->attrs != op1->attrs) + continue; + + // we find same operator with same inputs + matched = true; + break; + } + + if (!matched) + continue; + + // fprintf(stderr, "eliminate_identity_operator %s %s %s\n", op0->type.c_str(), op0->name.c_str(), op1->name.c_str()); + + int input_count = (int)op0->inputs.size(); + for (int j = 0; j < input_count; j++) + { + Operand* in0 = op0->inputs[j]; + + in0->consumers.erase(std::find(in0->consumers.begin(), in0->consumers.end(), op1)); + } + + int output_count = (int)op0->outputs.size(); + for (int j = 0; j < output_count; j++) + { + Operand* out0 = op0->outputs[j]; + Operand* out1 = op1->outputs[j]; + + for (auto x : out1->consumers) + { + for (size_t k = 0; k < x->inputs.size(); k++) + { + if (x->inputs[k] == out1) + x->inputs[k] = out0; + } + + out0->consumers.push_back(x); + } + + out1->consumers.clear(); + } + + // delete op1 and its output operands + for (int j = 0; j < output_count; j++) + { + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op1->outputs[j])); + delete op1->outputs[j]; + } + + op1->inputs.clear(); + op1->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op1)); + delete op1; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_identity_operator.h b/tools/pnnx/src/pass_level5/eliminate_identity_operator.h new file mode 100644 index 000000000..7ff0299a2 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_identity_operator.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eliminate_identity_operator(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp new file mode 100644 index 000000000..f3129a933 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp @@ -0,0 +1,91 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_noop_pad.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_noop_pad(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "F.pad") + continue; + + const std::vector& pad = op->params.at("pad").ai; + + bool noop_pad = true; + for (auto p : pad) + { + if (p != 0) + { + noop_pad = false; + break; + } + } + + if (!noop_pad) + continue; + + // delete noop-like pad + matched = true; + + for (auto& x : op->inputs) + { + x->remove_consumer(op); + } + + Operand* pad_out = op->outputs[0]; + + for (auto& x : pad_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == pad_out) + x->inputs[j] = op->inputs[0]; + } + + op->inputs[0]->consumers.push_back(x); + } + + pad_out->producer = 0; + pad_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), pad_out)); + delete pad_out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.h b/tools/pnnx/src/pass_level5/eliminate_noop_pad.h new file mode 100644 index 000000000..359c04634 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eliminate_noop_pad(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp new file mode 100644 index 000000000..a3f9f8c0f --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp @@ -0,0 +1,111 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_select_to_unbind.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_select_to_unbind(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Tensor.select") + continue; + + Operand* op_in = op->inputs[0]; + int input_rank = op_in->shape.size(); + if (input_rank == 0) + continue; + + int dim = op->params.at("dim").i; + const int select_dimsize = op_in->shape[dim]; + + // select 0..n + std::vector select_n(select_dimsize, 0); + std::vector select_n_ops(select_dimsize, 0); + for (auto x : op_in->consumers) + { + if (x->type != "Tensor.select") + continue; + + if (x->inputs[0] != op_in) + continue; + + int dim2 = x->params.at("dim").i; + int index2 = x->params.at("index").i; + + if (dim == dim2) + { + select_n[index2] = 1; + select_n_ops[index2] = x; + } + } + + bool select_full_index = true; + for (auto x : select_n) + { + if (x == 0) + { + select_full_index = false; + break; + } + } + + if (!select_full_index) + continue; + + matched = true; + + // delete all select ops and replace with unbind + Operator* op_unbind = graph.new_operator_before("torch.unbind", op->name, op); + op_unbind->params["dim"] = dim; + + op_unbind->inputs.push_back(op_in); + for (int j = 0; j < select_dimsize; j++) + { + op_in->consumers.erase(std::find(op_in->consumers.begin(), op_in->consumers.end(), select_n_ops[j])); + } + op_in->consumers.push_back(op_unbind); + + op_unbind->outputs.resize(select_dimsize); + for (int j = 0; j < select_dimsize; j++) + { + op_unbind->outputs[j] = select_n_ops[j]->outputs[0]; + select_n_ops[j]->outputs[0]->producer = op_unbind; + } + + for (int j = 0; j < select_dimsize; j++) + { + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), select_n_ops[j])); + delete select_n_ops[j]; + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h new file mode 100644 index 000000000..b48e64432 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_select_to_unbind(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp index fdcc73c36..49aabc536 100644 --- a/tools/pnnx/src/pass_ncnn.cpp +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -36,6 +36,7 @@ #include "pass_ncnn/fuse_deconvolution_activation.h" #include "pass_ncnn/fuse_deconvolutiondepthwise_activation.h" #include "pass_ncnn/fuse_innerproduct_activation.h" +#include "pass_ncnn/fuse_transpose_matmul.h" #include "pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" @@ -92,6 +93,7 @@ void pass_ncnn(Graph& g) ncnn::insert_split(g); ncnn::eliminate_noop(g); + ncnn::fuse_transpose_matmul(g); ncnn::fuse_convolution_activation(g); ncnn::fuse_convolution1d_activation(g); ncnn::fuse_convolutiondepthwise_activation(g); diff --git a/tools/pnnx/src/pass_ncnn/convert_attribute.cpp b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp index 9e0cf2ff7..a5bfb7425 100644 --- a/tools/pnnx/src/pass_ncnn/convert_attribute.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp @@ -48,6 +48,19 @@ void convert_attribute(Graph& graph) new_shape.push_back(data.shape[i]); } + if (new_shape.size() == 5 && batch_index == 233) + { + if (new_shape[0] == 1) + { + fprintf(stderr, "assume pnnx attribute 5-rank tensor has batch_index 0\n"); + new_shape.erase(new_shape.begin()); + } + else + { + fprintf(stderr, "pnnx attribute 5-rank tensor is not supported yet!\n"); + } + } + if (new_shape.size() == 1) { op->params["0"] = new_shape[0]; diff --git a/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.cpp b/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.cpp new file mode 100644 index 000000000..7ee0455fe --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.cpp @@ -0,0 +1,66 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_transpose_matmul.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_transpose_matmul_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_a 0 1 a +pnnx.Input input_b 0 1 b +Permute op_0 1 1 b bt 0=1 +MatMul op_1 2 1 a bt out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "MatMul"; + } + + const char* name_str() const + { + return "matmultransb"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& /*captured_attrs*/) const + { + op->params["0"] = 1; + } +}; + +void fuse_transpose_matmul(Graph& graph) +{ + fuse_transpose_matmul_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.h b/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.h new file mode 100644 index 000000000..400e4c5d6 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_transpose_matmul.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_transpose_matmul(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp index 0b22a5868..777b074ee 100644 --- a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp @@ -160,6 +160,28 @@ static void solve_batch_index_forward(Operand* operand) solve_batch_index_backward(r); } } + else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") + { + const std::vector& shape = op->params.at("shape").ai; + + if (shape[batch_index] == 1) + { + for (Operand* r : op->outputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } + else + { + // give up reshape across batch index + } + } else { for (Operand* r : op->outputs) @@ -207,6 +229,28 @@ static void solve_batch_index_backward(Operand* operand) solve_batch_index_forward(r); } } + else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") + { + const std::vector& shape = op->params.at("shape").ai; + + if (shape[batch_index] == 1) + { + for (Operand* r : op->inputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index; + + solve_batch_index_backward(r); + solve_batch_index_forward(r); + } + } + else + { + // give up reshape across batch index + } + } else { for (Operand* r : op->inputs) diff --git a/tools/pnnx/src/pass_ncnn/torch_matmul.cpp b/tools/pnnx/src/pass_ncnn/torch_matmul.cpp new file mode 100644 index 000000000..265b65156 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_matmul.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_matmul : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Input other 0 1 other +torch.matmul op_0 2 1 input other out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "MatMul"; + } + + const char* name_str() const + { + return "matmul"; + } + + void write(Operator* /*op*/, const std::map& /*captured_params*/) const + { + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_matmul, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_transpose.cpp b/tools/pnnx/src/pass_ncnn/torch_transpose.cpp index 626b0dbc5..f87ac8ace 100644 --- a/tools/pnnx/src/pass_ncnn/torch_transpose.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_transpose.cpp @@ -49,14 +49,24 @@ pnnx.Output output 1 0 out int dim0 = captured_params.at("dim0").i; int dim1 = captured_params.at("dim1").i; + + int input_rank = op->inputs[0]->shape.size(); + + if (dim0 < 0) + { + dim0 = input_rank + dim0; + } + if (dim1 < 0) + { + dim1 = input_rank + dim1; + } + if (dim0 == batch_index || dim1 == batch_index) { fprintf(stderr, "permute across batch dim is not supported yet!\n"); return; } - int input_rank = op->inputs[0]->shape.size(); - if (batch_index >= 0 && batch_index < input_rank) input_rank -= 1; diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index c975b252d..b5fad3e82 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -172,6 +172,7 @@ pnnx_add_test(torch_clamp) pnnx_add_test(torch_clone) pnnx_add_test(torch_flatten) pnnx_add_test(torch_logsumexp) +pnnx_add_test(torch_matmul) pnnx_add_test(torch_mean) pnnx_add_test(torch_norm) pnnx_add_test(torch_permute) @@ -200,6 +201,7 @@ pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) +pnnx_add_test(pnnx_fuse_select_to_unbind) if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_add_test(F_mish) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 45e237629..2dc2c95a3 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -130,6 +130,7 @@ pnnx_ncnn_add_test(torch_chunk) pnnx_ncnn_add_test(torch_clamp) pnnx_ncnn_add_test(torch_clone) pnnx_ncnn_add_test(torch_logsumexp) +pnnx_ncnn_add_test(torch_matmul) pnnx_ncnn_add_test(torch_mean) pnnx_ncnn_add_test(torch_permute) pnnx_ncnn_add_test(torch_prod) @@ -144,6 +145,8 @@ pnnx_ncnn_add_test(resnet18) pnnx_ncnn_add_test(shufflenet_v2_x1_0) pnnx_ncnn_add_test(squeezenet1_1) +pnnx_ncnn_add_test(ncnn_fuse_transpose_matmul) + if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_ncnn_add_test(F_mish) pnnx_ncnn_add_test(nn_Mish) diff --git a/tools/pnnx/tests/ncnn/test_ncnn_fuse_transpose_matmul.py b/tools/pnnx/tests/ncnn/test_ncnn_fuse_transpose_matmul.py new file mode 100644 index 000000000..af1b40bf9 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_ncnn_fuse_transpose_matmul.py @@ -0,0 +1,95 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1): + a = torch.matmul(a0, a1.transpose(-2, -1)) + b = torch.matmul(b0, b1.transpose(-2, -1)) + c = torch.matmul(c0, c1.transpose(-2, -1)) + d = torch.matmul(d0, d1.transpose(-2, -1)) + e = torch.matmul(e0, e1.transpose(-2, -1)) + f = torch.matmul(f0, f1.transpose(-2, -1)) + g = torch.matmul(g0, g1.transpose(-2, -1)) + h = torch.matmul(h0, h1.transpose(-2, -1)) + i = torch.matmul(i0, i1.transpose(-2, -1)) + j = torch.matmul(j0, j1.transpose(-2, -1)) + k = torch.matmul(k0, k1.transpose(-2, -1)) + l = torch.matmul(l0, l1.transpose(-2, -1)) + return a, b, c, d, e, f, g, h, i, j, k, l + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + a0 = torch.rand(14) + a1 = torch.rand(6, 14) + b0 = torch.rand(13) + b1 = torch.rand(7, 4, 13) + c0 = torch.rand(15) + c1 = torch.rand(5, 7, 9, 15) + d0 = torch.rand(23, 14) + d1 = torch.rand(25, 14) + e0 = torch.rand(4, 5) + e1 = torch.rand(10, 40, 5) + f0 = torch.rand(14, 6) + f1 = torch.rand(2, 4, 20, 6) + g0 = torch.rand(10, 23, 14) + g1 = torch.rand(5, 14) + h0 = torch.rand(7, 8, 13, 14) + h1 = torch.rand(35, 14) + i0 = torch.rand(10, 23, 14) + i1 = torch.rand(10, 5, 14) + j0 = torch.rand(10, 13, 18) + j1 = torch.rand(3, 1, 8, 18) + k0 = torch.rand(1, 5, 23, 11) + k1 = torch.rand(8, 1, 9, 11) + l0 = torch.rand(6, 9, 13, 14) + l1 = torch.rand(6, 9, 15, 14) + + a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1) + + # export torchscript + mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1)) + mod.save("test_ncnn_fuse_transpose_matmul.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_ncnn_fuse_transpose_matmul.pt inputshape=[14],[6,14],[13],[7,4,13],[15],[5,7,9,15],[23,14],[25,14],[4,5],[10,40,5],[14,6],[2,4,20,6],[10,23,14],[5,14],[7,8,13,14],[35,14],[10,23,14],[10,5,14],[10,13,18],[3,1,8,18],[1,5,23,11],[8,1,9,11],[6,9,13,14],[6,9,15,14]") + + # ncnn inference + import test_ncnn_fuse_transpose_matmul_ncnn + b = test_ncnn_fuse_transpose_matmul_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0.shape) + print(b0.shape) + print(a0) + print(b0) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_matmul.py b/tools/pnnx/tests/ncnn/test_torch_matmul.py new file mode 100644 index 000000000..b8faa3286 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_matmul.py @@ -0,0 +1,107 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1): + a = torch.matmul(a0, a1) + b = torch.matmul(b0, b1) + c = torch.matmul(c0, c1) + d = torch.matmul(d0, d1) + e = torch.matmul(e0, e1) + f = torch.matmul(f0, f1) + g = torch.matmul(g0, g1) + h = torch.matmul(h0, h1) + i = torch.matmul(i0, i1) + j = torch.matmul(j0, j1) + k = torch.matmul(k0, k1) + l = torch.matmul(l0, l1) + m = torch.matmul(m0, m1) + n = torch.matmul(n0, n1) + o = torch.matmul(o0, o1) + p = torch.matmul(p0, p1) + return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + a0 = torch.rand(13) + a1 = torch.rand(13) + b0 = torch.rand(14) + b1 = torch.rand(14, 6) + c0 = torch.rand(13) + c1 = torch.rand(7, 13, 4) + d0 = torch.rand(15) + d1 = torch.rand(5, 7, 15, 9) + e0 = torch.rand(5, 12) + e1 = torch.rand(12) + f0 = torch.rand(10, 3, 4) + f1 = torch.rand(4) + g0 = torch.rand(6, 3, 7, 14) + g1 = torch.rand(14) + h0 = torch.rand(23, 14) + h1 = torch.rand(14, 25) + i0 = torch.rand(4, 5) + i1 = torch.rand(10, 5, 40) + j0 = torch.rand(14, 6) + j1 = torch.rand(2, 4, 6, 20) + k0 = torch.rand(10, 23, 14) + k1 = torch.rand(14, 5) + l0 = torch.rand(7, 8, 13, 14) + l1 = torch.rand(14, 35) + m0 = torch.rand(10, 23, 14) + m1 = torch.rand(10, 14, 5) + n0 = torch.rand(10, 13, 18) + n1 = torch.rand(3, 1, 18, 8) + o0 = torch.rand(1, 5, 23, 11) + o1 = torch.rand(8, 1, 11, 9) + p0 = torch.rand(6, 9, 13, 14) + p1 = torch.rand(6, 9, 14, 15) + + a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1) + + # export torchscript + mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1)) + mod.save("test_torch_matmul.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_matmul.pt inputshape=[13],[13],[14],[14,6],[13],[7,13,4],[15],[5,7,15,9],[5,12],[12],[10,3,4],[4],[6,3,7,14],[14],[23,14],[14,25],[4,5],[10,5,40],[14,6],[2,4,6,20],[10,23,14],[14,5],[7,8,13,14],[14,35],[10,23,14],[10,14,5],[10,13,18],[3,1,18,8],[1,5,23,11],[8,1,11,9],[6,9,13,14],[6,9,14,15]") + + # ncnn inference + import test_torch_matmul_ncnn + b = test_torch_matmul_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0.shape) + print(b0.shape) + print(a0) + print(b0) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py b/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py new file mode 100644 index 000000000..8b0041928 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_select_to_unbind.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x0 = torch.select(x, 0, 0) + x1 = torch.select(x, 0, 1) + x2 = torch.select(x, 0, 2) + y0 = torch.select(x, 1, 0) + y1 = torch.select(x, 1, 1) + y2 = torch.select(x, 1, 2) + y3 = torch.select(x, 1, 3) + z0 = torch.select(x, 2, 0) + z1 = torch.select(x, 2, 1) + z2 = torch.select(x, 2, 2) + z3 = torch.select(x, 2, 3) + z4 = torch.select(x, 2, 4) + + return x0, x1, x2, y0, y1, y2, y3, z0, z1, z2, z3, z4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 4, 5) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_select_to_unbind.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_select_to_unbind.pt inputshape=[3,4,5]") + + # pnnx inference + import test_pnnx_fuse_select_to_unbind_pnnx + b = test_pnnx_fuse_select_to_unbind_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_matmul.py b/tools/pnnx/tests/test_torch_matmul.py new file mode 100644 index 000000000..b47c0a1e4 --- /dev/null +++ b/tools/pnnx/tests/test_torch_matmul.py @@ -0,0 +1,103 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1): + a = torch.matmul(a0, a1) + b = torch.matmul(b0, b1) + c = torch.matmul(c0, c1) + d = torch.matmul(d0, d1) + e = torch.matmul(e0, e1) + f = torch.matmul(f0, f1) + g = torch.matmul(g0, g1) + h = torch.matmul(h0, h1) + i = torch.matmul(i0, i1) + j = torch.matmul(j0, j1) + k = torch.matmul(k0, k1) + l = torch.matmul(l0, l1) + m = torch.matmul(m0, m1) + n = torch.matmul(n0, n1) + o = torch.matmul(o0, o1) + p = torch.matmul(p0, p1) + return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + a0 = torch.rand(13) + a1 = torch.rand(13) + b0 = torch.rand(14) + b1 = torch.rand(14, 6) + c0 = torch.rand(13) + c1 = torch.rand(7, 13, 4) + d0 = torch.rand(15) + d1 = torch.rand(5, 7, 15, 9) + e0 = torch.rand(5, 12) + e1 = torch.rand(12) + f0 = torch.rand(10, 3, 4) + f1 = torch.rand(4) + g0 = torch.rand(6, 3, 7, 14) + g1 = torch.rand(14) + h0 = torch.rand(23, 14) + h1 = torch.rand(14, 25) + i0 = torch.rand(4, 5) + i1 = torch.rand(10, 5, 40) + j0 = torch.rand(14, 6) + j1 = torch.rand(2, 4, 6, 20) + k0 = torch.rand(10, 23, 14) + k1 = torch.rand(14, 5) + l0 = torch.rand(7, 8, 13, 14) + l1 = torch.rand(14, 35) + m0 = torch.rand(10, 23, 14) + m1 = torch.rand(10, 14, 5) + n0 = torch.rand(10, 13, 18) + n1 = torch.rand(3, 1, 18, 8) + o0 = torch.rand(1, 5, 23, 11) + o1 = torch.rand(8, 1, 11, 9) + p0 = torch.rand(6, 9, 13, 14) + p1 = torch.rand(6, 9, 14, 15) + + a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1) + + # export torchscript + mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1)) + mod.save("test_torch_matmul.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_matmul.pt inputshape=[13],[13],[14],[14,6],[13],[7,13,4],[15],[5,7,15,9],[5,12],[12],[10,3,4],[4],[6,3,7,14],[14],[23,14],[14,25],[4,5],[10,5,40],[14,6],[2,4,6,20],[10,23,14],[14,5],[7,8,13,14],[14,35],[10,23,14],[10,14,5],[10,13,18],[3,1,18,8],[1,5,23,11],[8,1,11,9],[6,9,13,14],[6,9,14,15]") + + # pnnx inference + import test_torch_matmul_pnnx + b = test_torch_matmul_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)