| @@ -149,6 +149,7 @@ ncnn_add_layer(ConvolutionDepthWise1D) | |||
| ncnn_add_layer(Convolution3D) | |||
| ncnn_add_layer(ConvolutionDepthWise3D) | |||
| ncnn_add_layer(Pooling3D) | |||
| ncnn_add_layer(MatMul) | |||
| if(NCNN_VULKAN) | |||
| ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) | |||
| @@ -0,0 +1,401 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "matmul.h" | |||
| namespace ncnn { | |||
| MatMul::MatMul() | |||
| { | |||
| one_blob_only = false; | |||
| support_inplace = false; | |||
| } | |||
| int MatMul::load_param(const ParamDict& pd) | |||
| { | |||
| transB = pd.get(0, 0); | |||
| return 0; | |||
| } | |||
| static void transpose(const Mat& X, Mat& XT, const Option& opt) | |||
| { | |||
| const int w = X.w; | |||
| const int h = X.h; | |||
| const float* pX = X; | |||
| float* pXT = XT; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int i = 0; i < w; i++) | |||
| { | |||
| float* ptr = pXT + i * h; | |||
| for (int j = 0; j < h; j++) | |||
| { | |||
| ptr[j] = pX[j * w + i]; | |||
| } | |||
| } | |||
| } | |||
| static void matmul_transb(const Mat& A, const Mat& B, Mat& top_blob, const Option& opt) | |||
| { | |||
| const int M = A.h; | |||
| const int K = A.w; // assert A.w == B.w | |||
| const int N = B.h; | |||
| const float* pA = A; | |||
| const float* pB = B; | |||
| float* pOut = top_blob; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int i = 0; i < M; i++) | |||
| { | |||
| const float* ptrA = pA + i * K; | |||
| float* outptr = pOut + i * N; | |||
| for (int j = 0; j < N; j++) | |||
| { | |||
| const float* ptrB = pB + j * K; | |||
| float sum = 0.f; | |||
| for (int k = 0; k < K; k++) | |||
| { | |||
| sum += ptrA[k] * ptrB[k]; | |||
| } | |||
| *outptr++ = sum; | |||
| } | |||
| } | |||
| } | |||
| int MatMul::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const | |||
| { | |||
| const Mat& A = bottom_blobs[0]; | |||
| const Mat& B = bottom_blobs[1]; | |||
| Mat& top_blob = top_blobs[0]; | |||
| const int Adims = A.dims; | |||
| const int Bdims = B.dims; | |||
| const int max_ABdims = std::max(Adims, Bdims); | |||
| const size_t elemsize = A.elemsize; | |||
| if (Adims == 1 && Bdims == 1) | |||
| { | |||
| // dot product | |||
| top_blob.create(1, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| const int K = A.w; // assert A.w == B.w | |||
| const float* ptrA = A; | |||
| const float* ptrB = B; | |||
| float sum = 0.f; | |||
| for (int k = 0; k < K; k++) | |||
| { | |||
| sum += ptrA[k] * ptrB[k]; | |||
| } | |||
| top_blob[0] = sum; | |||
| } | |||
| else if (Adims == 2 && Bdims == 2) | |||
| { | |||
| // matrix multiply | |||
| const int M = A.h; | |||
| const int N = transB == 0 ? B.w : B.h; | |||
| top_blob.create(N, M, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| Mat BT; | |||
| if (transB == 0) | |||
| { | |||
| BT.create(B.h, B.w, elemsize, opt.workspace_allocator); | |||
| if (BT.empty()) | |||
| return -100; | |||
| transpose(B, BT, opt); | |||
| } | |||
| else | |||
| { | |||
| BT = B; | |||
| } | |||
| matmul_transb(A, BT, top_blob, opt); | |||
| } | |||
| else if (Adims == 1 && Bdims == 2) | |||
| { | |||
| // matrix multiply | |||
| const int N = transB == 0 ? B.w : B.h; | |||
| Mat top_blob1(N, 1, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat A1 = A.reshape(A.w, 1); | |||
| Mat BT; | |||
| if (transB == 0) | |||
| { | |||
| BT.create(B.h, B.w, elemsize, opt.workspace_allocator); | |||
| if (BT.empty()) | |||
| return -100; | |||
| transpose(B, BT, opt); | |||
| } | |||
| else | |||
| { | |||
| BT = B; | |||
| } | |||
| matmul_transb(A1, BT, top_blob1, opt); | |||
| top_blob = top_blob1.reshape(N); | |||
| } | |||
| else if (Adims == 2 && Bdims == 1) | |||
| { | |||
| // matrix multiply | |||
| const int M = A.h; | |||
| Mat top_blob1(1, M, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat BT = B.reshape(B.w, 1); | |||
| matmul_transb(A, BT, top_blob1, opt); | |||
| top_blob = top_blob1.reshape(M); | |||
| } | |||
| else if (Adims == 1 && Bdims > 2) | |||
| { | |||
| // batched matrix multiply | |||
| const int N = transB == 0 ? B.w : B.h; | |||
| const int batch_size = B.d * B.c; | |||
| Mat top_blob1(N, 1, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat A1 = A.reshape(A.w, 1); | |||
| Mat B1 = B.reshape(B.w, B.h, batch_size); | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| Mat BT; | |||
| if (transB == 0) | |||
| { | |||
| BT.create(B.h, B.w, elemsize, opt.workspace_allocator); | |||
| if (BT.empty()) | |||
| return -100; | |||
| transpose(B1.channel(p), BT, opt); | |||
| } | |||
| else | |||
| { | |||
| BT = B1.channel(p); | |||
| } | |||
| Mat top_blob1_p = top_blob1.channel(p); | |||
| matmul_transb(A1, BT, top_blob1_p, opt); | |||
| } | |||
| if (Bdims == 3) | |||
| top_blob = top_blob1.reshape(N, B.d * B.c); | |||
| else | |||
| top_blob = top_blob1.reshape(N, B.d, B.c); | |||
| } | |||
| else if (Adims > 2 && Bdims == 1) | |||
| { | |||
| // batched matrix multiply | |||
| const int M = A.h; | |||
| const int batch_size = A.d * A.c; | |||
| Mat top_blob1(1, M, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob1.empty()) | |||
| return -100; | |||
| Mat A1 = A.reshape(A.w, A.h, batch_size); | |||
| Mat BT = B.reshape(B.w, 1); | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| Mat top_blob1_p = top_blob1.channel(p); | |||
| matmul_transb(A1.channel(p), BT, top_blob1_p, opt); | |||
| } | |||
| if (Adims == 3) | |||
| top_blob = top_blob1.reshape(M, A.d * A.c); | |||
| else | |||
| top_blob = top_blob1.reshape(M, A.d, A.c); | |||
| } | |||
| else if (max_ABdims == 3) | |||
| { | |||
| Mat A1 = Adims == 2 ? A.reshape(A.w, A.h, 1) : A; | |||
| Mat B1 = Bdims == 2 ? B.reshape(B.w, B.h, 1) : B; | |||
| const int M = A1.h; | |||
| const int N = transB == 0 ? B1.w : B1.h; | |||
| const int batch_size = std::max(A1.c, B1.c); | |||
| top_blob.create(N, M, batch_size, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| Mat BT0; | |||
| if (B1.c == 1) | |||
| { | |||
| if (transB == 0) | |||
| { | |||
| BT0.create(B1.h, B1.w, elemsize, opt.workspace_allocator); | |||
| if (BT0.empty()) | |||
| return -100; | |||
| transpose(B1.channel(0), BT0, opt); | |||
| } | |||
| else | |||
| { | |||
| BT0 = B1.channel(0); | |||
| } | |||
| } | |||
| for (int p = 0; p < batch_size; p++) | |||
| { | |||
| int Ap = A1.c == 1 ? 0 : p; | |||
| int Bp = B1.c == 1 ? 0 : p; | |||
| Mat BT; | |||
| if (B1.c == 1) | |||
| { | |||
| BT = BT0; | |||
| } | |||
| else | |||
| { | |||
| if (transB == 0) | |||
| { | |||
| BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator); | |||
| if (BT.empty()) | |||
| return -100; | |||
| transpose(B1.channel(Bp), BT, opt); | |||
| } | |||
| else | |||
| { | |||
| BT = B1.channel(Bp); | |||
| } | |||
| } | |||
| Mat top_blob_p = top_blob.channel(p); | |||
| matmul_transb(A1.channel(Ap), BT, top_blob_p, opt); | |||
| } | |||
| } | |||
| else if (max_ABdims == 4) | |||
| { | |||
| Mat A1 = Adims == 3 ? A.reshape(A.w, A.h, A.c, 1) : A; | |||
| Mat B1 = Bdims == 3 ? B.reshape(B.w, B.h, B.c, 1) : B; | |||
| const int M = A1.h; | |||
| const int N = transB == 0 ? B1.w : B1.h; | |||
| const int batch_size_d = std::max(A1.d, B1.d); | |||
| const int batch_size_c = std::max(A1.c, B1.c); | |||
| top_blob.create(N, M, batch_size_d, batch_size_c, elemsize, opt.blob_allocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| Mat BT00; | |||
| if (B1.d == 1 && B1.c == 1) | |||
| { | |||
| if (transB == 0) | |||
| { | |||
| BT00.create(B1.h, B1.w, elemsize, opt.workspace_allocator); | |||
| if (BT00.empty()) | |||
| return -100; | |||
| transpose(B1.channel(0).depth(0), BT00, opt); | |||
| } | |||
| else | |||
| { | |||
| BT00 = B1.channel(0).depth(0); | |||
| } | |||
| } | |||
| for (int p = 0; p < batch_size_c; p++) | |||
| { | |||
| int Ap = A1.c == 1 ? 0 : p; | |||
| int Bp = B1.c == 1 ? 0 : p; | |||
| Mat BT0x; | |||
| if (B1.d == 1 && B1.c != 1) | |||
| { | |||
| if (transB == 0) | |||
| { | |||
| BT0x.create(B1.h, B1.w, elemsize, opt.workspace_allocator); | |||
| if (BT0x.empty()) | |||
| return -100; | |||
| transpose(B1.channel(Bp).depth(0), BT0x, opt); | |||
| } | |||
| else | |||
| { | |||
| BT0x = B1.channel(Bp).depth(0); | |||
| } | |||
| } | |||
| for (int q = 0; q < batch_size_d; q++) | |||
| { | |||
| int Ad = A1.d == 1 ? 0 : q; | |||
| int Bd = B1.d == 1 ? 0 : q; | |||
| Mat BT; | |||
| if (B1.d == 1 && B1.c == 1) | |||
| { | |||
| BT = BT00; | |||
| } | |||
| else if (B1.d == 1 && B1.c != 1) | |||
| { | |||
| BT = BT0x; | |||
| } | |||
| else | |||
| { | |||
| if (transB == 0) | |||
| { | |||
| BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator); | |||
| if (BT.empty()) | |||
| return -100; | |||
| transpose(B1.channel(Bp).depth(Bd), BT, opt); | |||
| } | |||
| else | |||
| { | |||
| BT = B1.channel(Bp).depth(Bd); | |||
| } | |||
| } | |||
| Mat top_blob_p_q = top_blob.channel(p).depth(q); | |||
| matmul_transb(A1.channel(Ap).depth(Ad), BT, top_blob_p_q, opt); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| NCNN_LOGE("impossible matmul %d %d", Adims, Bdims); | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,37 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef LAYER_MATMUL_H | |||
| #define LAYER_MATMUL_H | |||
| #include "layer.h" | |||
| namespace ncnn { | |||
| class MatMul : public Layer | |||
| { | |||
| public: | |||
| MatMul(); | |||
| virtual int load_param(const ParamDict& pd); | |||
| virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const; | |||
| public: | |||
| int transB; | |||
| }; | |||
| } // namespace ncnn | |||
| #endif // LAYER_MATMUL_H | |||
| @@ -90,6 +90,7 @@ ncnn_add_layer_test(Interp) | |||
| ncnn_add_layer_test(LayerNorm) | |||
| ncnn_add_layer_test(LRN) | |||
| ncnn_add_layer_test(LSTM) | |||
| ncnn_add_layer_test(MatMul) | |||
| ncnn_add_layer_test(MemoryData) | |||
| ncnn_add_layer_test(Mish) | |||
| ncnn_add_layer_test(MultiHeadAttention) | |||
| @@ -0,0 +1,255 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "layer/matmul.h" | |||
| #include "testutil.h" | |||
| static int test_matmul(const ncnn::Mat& a, const ncnn::Mat& b) | |||
| { | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, 0); // transB | |||
| std::vector<ncnn::Mat> weights(0); | |||
| std::vector<ncnn::Mat> as(2); | |||
| as[0] = a; | |||
| as[1] = b; | |||
| int ret = test_layer<ncnn::MatMul>("MatMul", pd, weights, as); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_matmul failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c); | |||
| } | |||
| return ret; | |||
| } | |||
| static int test_matmul_transb(const ncnn::Mat& a, const ncnn::Mat& b) | |||
| { | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, 1); // transB | |||
| std::vector<ncnn::Mat> weights(0); | |||
| std::vector<ncnn::Mat> as(2); | |||
| as[0] = a; | |||
| as[1] = b; | |||
| int ret = test_layer<ncnn::MatMul>("MatMul", pd, weights, as); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_matmul_transb failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c); | |||
| } | |||
| return ret; | |||
| } | |||
| static int test_matmul_0() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(124), RandomMat(124)) | |||
| || test_matmul(RandomMat(127), RandomMat(127)) | |||
| || test_matmul(RandomMat(128), RandomMat(128)); | |||
| } | |||
| static int test_matmul_1() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(5), RandomMat(6, 5)) | |||
| || test_matmul(RandomMat(16), RandomMat(12, 16)) | |||
| || test_matmul(RandomMat(11), RandomMat(16, 11)) | |||
| || test_matmul_transb(RandomMat(5), RandomMat(5, 6)) | |||
| || test_matmul_transb(RandomMat(16), RandomMat(16, 12)) | |||
| || test_matmul_transb(RandomMat(11), RandomMat(11, 16)); | |||
| } | |||
| static int test_matmul_2() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(13), RandomMat(7, 13, 12)) | |||
| || test_matmul(RandomMat(24), RandomMat(6, 24, 16)) | |||
| || test_matmul(RandomMat(20), RandomMat(8, 20, 19)) | |||
| || test_matmul_transb(RandomMat(13), RandomMat(13, 7, 12)) | |||
| || test_matmul_transb(RandomMat(24), RandomMat(24, 6, 16)) | |||
| || test_matmul_transb(RandomMat(20), RandomMat(20, 8, 19)); | |||
| } | |||
| static int test_matmul_3() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(13), RandomMat(7, 13, 5, 12)) | |||
| || test_matmul(RandomMat(24), RandomMat(6, 24, 4, 16)) | |||
| || test_matmul(RandomMat(20), RandomMat(8, 20, 3, 19)) | |||
| || test_matmul_transb(RandomMat(13), RandomMat(13, 7, 5, 12)) | |||
| || test_matmul_transb(RandomMat(24), RandomMat(24, 6, 4, 16)) | |||
| || test_matmul_transb(RandomMat(20), RandomMat(20, 8, 3, 19)); | |||
| } | |||
| static int test_matmul_4() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(5, 6), RandomMat(5)) | |||
| || test_matmul(RandomMat(16, 12), RandomMat(16)) | |||
| || test_matmul(RandomMat(11, 16), RandomMat(11)); | |||
| } | |||
| static int test_matmul_5() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(32, 3, 10), RandomMat(32)) | |||
| || test_matmul(RandomMat(31, 4, 16), RandomMat(31)) | |||
| || test_matmul(RandomMat(28, 5, 28), RandomMat(28)); | |||
| } | |||
| static int test_matmul_6() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(32, 3, 4, 10), RandomMat(32)) | |||
| || test_matmul(RandomMat(31, 4, 5, 16), RandomMat(31)) | |||
| || test_matmul(RandomMat(18, 5, 6, 28), RandomMat(18)); | |||
| } | |||
| static int test_matmul_7() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 10), RandomMat(5, 14)) | |||
| || test_matmul(RandomMat(16, 16), RandomMat(10, 16)) | |||
| || test_matmul(RandomMat(14, 28), RandomMat(9, 14)) | |||
| || test_matmul_transb(RandomMat(14, 10), RandomMat(14, 5)) | |||
| || test_matmul_transb(RandomMat(16, 16), RandomMat(16, 10)) | |||
| || test_matmul_transb(RandomMat(14, 28), RandomMat(14, 9)); | |||
| } | |||
| static int test_matmul_8() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(5, 4), RandomMat(4, 5, 12)) | |||
| || test_matmul(RandomMat(5, 14), RandomMat(5, 5, 16)) | |||
| || test_matmul(RandomMat(5, 24), RandomMat(6, 5, 19)) | |||
| || test_matmul_transb(RandomMat(5, 4), RandomMat(5, 4, 12)) | |||
| || test_matmul_transb(RandomMat(5, 14), RandomMat(5, 5, 16)) | |||
| || test_matmul_transb(RandomMat(5, 24), RandomMat(5, 6, 19)); | |||
| } | |||
| static int test_matmul_9() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(5, 4), RandomMat(4, 5, 2, 12)) | |||
| || test_matmul(RandomMat(5, 14), RandomMat(5, 5, 3, 16)) | |||
| || test_matmul(RandomMat(5, 24), RandomMat(6, 5, 4, 19)) | |||
| || test_matmul_transb(RandomMat(5, 4), RandomMat(5, 4, 2, 12)) | |||
| || test_matmul_transb(RandomMat(5, 14), RandomMat(5, 5, 3, 16)) | |||
| || test_matmul_transb(RandomMat(5, 24), RandomMat(5, 6, 4, 19)); | |||
| } | |||
| static int test_matmul_10() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14)) | |||
| || test_matmul(RandomMat(16, 22, 16), RandomMat(10, 16)) | |||
| || test_matmul(RandomMat(14, 20, 28), RandomMat(9, 14)) | |||
| || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5)) | |||
| || test_matmul_transb(RandomMat(16, 22, 16), RandomMat(16, 10)) | |||
| || test_matmul_transb(RandomMat(14, 20, 28), RandomMat(14, 9)); | |||
| } | |||
| static int test_matmul_11() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 13, 2, 10), RandomMat(5, 14)) | |||
| || test_matmul(RandomMat(16, 12, 3, 16), RandomMat(10, 16)) | |||
| || test_matmul(RandomMat(14, 10, 4, 28), RandomMat(9, 14)) | |||
| || test_matmul_transb(RandomMat(14, 13, 2, 10), RandomMat(14, 5)) | |||
| || test_matmul_transb(RandomMat(16, 12, 3, 16), RandomMat(16, 10)) | |||
| || test_matmul_transb(RandomMat(14, 10, 4, 28), RandomMat(14, 9)); | |||
| } | |||
| static int test_matmul_12() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14, 10)) | |||
| || test_matmul(RandomMat(16, 22, 16), RandomMat(10, 16, 16)) | |||
| || test_matmul(RandomMat(14, 20, 28), RandomMat(9, 14, 28)) | |||
| || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5, 10)) | |||
| || test_matmul_transb(RandomMat(16, 22, 16), RandomMat(16, 10, 16)) | |||
| || test_matmul_transb(RandomMat(14, 20, 28), RandomMat(14, 9, 28)); | |||
| } | |||
| static int test_matmul_13() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 23, 10), RandomMat(5, 14, 1, 16)) | |||
| || test_matmul(RandomMat(16, 22, 9), RandomMat(10, 16, 1, 17)) | |||
| || test_matmul(RandomMat(14, 20, 8), RandomMat(9, 14, 1, 18)) | |||
| || test_matmul_transb(RandomMat(14, 23, 10), RandomMat(14, 5, 1, 16)) | |||
| || test_matmul_transb(RandomMat(16, 22, 9), RandomMat(16, 10, 1, 17)) | |||
| || test_matmul_transb(RandomMat(14, 20, 8), RandomMat(14, 9, 1, 18)); | |||
| } | |||
| static int test_matmul_14() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 23, 10, 1), RandomMat(5, 14, 1, 16)) | |||
| || test_matmul(RandomMat(16, 22, 9, 1), RandomMat(10, 16, 1, 17)) | |||
| || test_matmul(RandomMat(14, 20, 8, 1), RandomMat(9, 14, 1, 18)) | |||
| || test_matmul_transb(RandomMat(14, 23, 10, 1), RandomMat(14, 5, 1, 16)) | |||
| || test_matmul_transb(RandomMat(16, 22, 9, 1), RandomMat(16, 10, 1, 17)) | |||
| || test_matmul_transb(RandomMat(14, 20, 8, 1), RandomMat(14, 9, 1, 18)); | |||
| } | |||
| static int test_matmul_15() | |||
| { | |||
| return 0 | |||
| || test_matmul(RandomMat(14, 23, 10, 16), RandomMat(5, 14, 10, 16)) | |||
| || test_matmul(RandomMat(16, 22, 9, 17), RandomMat(10, 16, 9, 17)) | |||
| || test_matmul(RandomMat(14, 20, 8, 18), RandomMat(9, 14, 8, 18)) | |||
| || test_matmul_transb(RandomMat(14, 23, 10, 16), RandomMat(14, 5, 10, 16)) | |||
| || test_matmul_transb(RandomMat(16, 22, 9, 17), RandomMat(16, 10, 9, 17)) | |||
| || test_matmul_transb(RandomMat(14, 20, 8, 18), RandomMat(14, 9, 8, 18)); | |||
| } | |||
| int main() | |||
| { | |||
| SRAND(7767517); | |||
| return 0 | |||
| || test_matmul_0() | |||
| || test_matmul_1() | |||
| || test_matmul_2() | |||
| || test_matmul_3() | |||
| || test_matmul_4() | |||
| || test_matmul_5() | |||
| || test_matmul_6() | |||
| || test_matmul_7() | |||
| || test_matmul_8() | |||
| || test_matmul_9() | |||
| || test_matmul_10() | |||
| || test_matmul_11() | |||
| || test_matmul_12() | |||
| || test_matmul_13() | |||
| || test_matmul_14() | |||
| || test_matmul_15(); | |||
| } | |||
| @@ -234,7 +234,9 @@ set(pnnx_pass_level4_SRCS | |||
| set(pnnx_pass_level5_SRCS | |||
| pass_level5/eliminate_dropout.cpp | |||
| pass_level5/eliminate_identity_operator.cpp | |||
| pass_level5/eliminate_maxpool_indices.cpp | |||
| pass_level5/eliminate_noop_pad.cpp | |||
| pass_level5/eliminate_slice.cpp | |||
| pass_level5/eliminate_view_reshape.cpp | |||
| pass_level5/eval_expression.cpp | |||
| @@ -247,6 +249,7 @@ set(pnnx_pass_level5_SRCS | |||
| pass_level5/fuse_convtranspose2d_batchnorm2d.cpp | |||
| pass_level5/fuse_contiguous_view.cpp | |||
| pass_level5/fuse_linear_batchnorm1d.cpp | |||
| pass_level5/fuse_select_to_unbind.cpp | |||
| pass_level5/fuse_slice_indices.cpp | |||
| pass_level5/unroll_rnn_op.cpp | |||
| ) | |||
| @@ -274,6 +277,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/fuse_deconvolution_activation.cpp | |||
| pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp | |||
| pass_ncnn/fuse_innerproduct_activation.cpp | |||
| pass_ncnn/fuse_transpose_matmul.cpp | |||
| pass_ncnn/F_adaptive_avg_pool1d.cpp | |||
| pass_ncnn/F_adaptive_avg_pool2d.cpp | |||
| @@ -390,6 +394,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/torch_clone.cpp | |||
| pass_ncnn/torch_flatten.cpp | |||
| pass_ncnn/torch_logsumexp.cpp | |||
| pass_ncnn/torch_matmul.cpp | |||
| pass_ncnn/torch_mean.cpp | |||
| pass_ncnn/torch_permute.cpp | |||
| pass_ncnn/torch_prod.cpp | |||
| @@ -268,6 +268,38 @@ Parameter::Parameter(const torch::jit::Value* value) | |||
| { | |||
| } | |||
| bool operator==(const Parameter& lhs, const Parameter& rhs) | |||
| { | |||
| if (lhs.type != rhs.type) | |||
| return false; | |||
| if (lhs.type == 0) | |||
| return true; | |||
| if (lhs.type == 1 && lhs.b == rhs.b) | |||
| return true; | |||
| if (lhs.type == 2 && lhs.i == rhs.i) | |||
| return true; | |||
| if (lhs.type == 3 && lhs.f == rhs.f) | |||
| return true; | |||
| if (lhs.type == 4 && lhs.s == rhs.s) | |||
| return true; | |||
| if (lhs.type == 5 && lhs.ai == rhs.ai) | |||
| return true; | |||
| if (lhs.type == 6 && lhs.af == rhs.af) | |||
| return true; | |||
| if (lhs.type == 7 && lhs.as == rhs.as) | |||
| return true; | |||
| return false; | |||
| } | |||
| Attribute::Attribute(const at::Tensor& t) | |||
| { | |||
| type = get_at_tensor_type(t.scalar_type()); | |||
| @@ -343,6 +375,23 @@ Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector | |||
| } | |||
| } | |||
| bool operator==(const Attribute& lhs, const Attribute& rhs) | |||
| { | |||
| if (lhs.type != rhs.type) | |||
| return false; | |||
| if (lhs.type == 0) | |||
| return true; | |||
| if (lhs.shape != rhs.shape) | |||
| return false; | |||
| if (lhs.data != rhs.data) | |||
| return false; | |||
| return true; | |||
| } | |||
| Parameter Parameter::parse_from_string(const std::string& value) | |||
| { | |||
| Parameter p; | |||
| @@ -132,6 +132,8 @@ public: | |||
| std::vector<std::string> as; | |||
| }; | |||
| bool operator==(const Parameter& lhs, const Parameter& rhs); | |||
| class Attribute | |||
| { | |||
| public: | |||
| @@ -151,6 +153,8 @@ public: | |||
| std::vector<char> data; | |||
| }; | |||
| bool operator==(const Attribute& lhs, const Attribute& rhs); | |||
| class Operator; | |||
| class Operand | |||
| { | |||
| @@ -16,6 +16,8 @@ | |||
| #include "pass_level5/fold_constants.h" | |||
| #include "pass_level5/eliminate_dropout.h" | |||
| #include "pass_level5/eliminate_identity_operator.h" | |||
| #include "pass_level5/eliminate_noop_pad.h" | |||
| #include "pass_level5/eliminate_slice.h" | |||
| #include "pass_level5/eliminate_view_reshape.h" | |||
| #include "pass_level5/eval_expression.h" | |||
| @@ -27,6 +29,7 @@ | |||
| #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" | |||
| #include "pass_level5/fuse_contiguous_view.h" | |||
| #include "pass_level5/fuse_linear_batchnorm1d.h" | |||
| #include "pass_level5/fuse_select_to_unbind.h" | |||
| #include "pass_level5/fuse_slice_indices.h" | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| #include "pass_level4/canonicalize.h" | |||
| @@ -44,6 +47,10 @@ void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_cons | |||
| fuse_slice_indices(g); | |||
| eliminate_identity_operator(g); | |||
| fuse_select_to_unbind(g); | |||
| fuse_conv1d_batchnorm1d(g); | |||
| fuse_conv2d_batchnorm2d(g); | |||
| @@ -54,12 +61,14 @@ void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_cons | |||
| fuse_linear_batchnorm1d(g); | |||
| eliminate_noop_pad(g); | |||
| eliminate_dropout(g); | |||
| fuse_contiguous_view(g); | |||
| eliminate_view_reshape(g); | |||
| eliminate_dropout(g); | |||
| fuse_channel_shuffle(g); | |||
| fold_constants(g, foldable_constants); | |||
| @@ -0,0 +1,118 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "eliminate_identity_operator.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void eliminate_identity_operator(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op0 = graph.ops[i]; | |||
| if (op0->type == "pnnx.Input" || op0->type == "pnnx.Output") | |||
| continue; | |||
| Operator* op1 = 0; | |||
| for (size_t j = i + 1; j < graph.ops.size(); j++) | |||
| { | |||
| op1 = graph.ops[j]; | |||
| if (op1->type == "pnnx.Input" || op1->type == "pnnx.Output") | |||
| continue; | |||
| if (op0->type != op1->type) | |||
| continue; | |||
| if (op0->inputs != op1->inputs) | |||
| continue; | |||
| if (op0->outputs.size() != op1->outputs.size()) | |||
| continue; | |||
| if (op0->params != op1->params) | |||
| continue; | |||
| if (op0->attrs != op1->attrs) | |||
| continue; | |||
| // we find same operator with same inputs | |||
| matched = true; | |||
| break; | |||
| } | |||
| if (!matched) | |||
| continue; | |||
| // fprintf(stderr, "eliminate_identity_operator %s %s %s\n", op0->type.c_str(), op0->name.c_str(), op1->name.c_str()); | |||
| int input_count = (int)op0->inputs.size(); | |||
| for (int j = 0; j < input_count; j++) | |||
| { | |||
| Operand* in0 = op0->inputs[j]; | |||
| in0->consumers.erase(std::find(in0->consumers.begin(), in0->consumers.end(), op1)); | |||
| } | |||
| int output_count = (int)op0->outputs.size(); | |||
| for (int j = 0; j < output_count; j++) | |||
| { | |||
| Operand* out0 = op0->outputs[j]; | |||
| Operand* out1 = op1->outputs[j]; | |||
| for (auto x : out1->consumers) | |||
| { | |||
| for (size_t k = 0; k < x->inputs.size(); k++) | |||
| { | |||
| if (x->inputs[k] == out1) | |||
| x->inputs[k] = out0; | |||
| } | |||
| out0->consumers.push_back(x); | |||
| } | |||
| out1->consumers.clear(); | |||
| } | |||
| // delete op1 and its output operands | |||
| for (int j = 0; j < output_count; j++) | |||
| { | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op1->outputs[j])); | |||
| delete op1->outputs[j]; | |||
| } | |||
| op1->inputs.clear(); | |||
| op1->outputs.clear(); | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op1)); | |||
| delete op1; | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void eliminate_identity_operator(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,91 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "eliminate_noop_pad.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void eliminate_noop_pad(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "F.pad") | |||
| continue; | |||
| const std::vector<int>& pad = op->params.at("pad").ai; | |||
| bool noop_pad = true; | |||
| for (auto p : pad) | |||
| { | |||
| if (p != 0) | |||
| { | |||
| noop_pad = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!noop_pad) | |||
| continue; | |||
| // delete noop-like pad | |||
| matched = true; | |||
| for (auto& x : op->inputs) | |||
| { | |||
| x->remove_consumer(op); | |||
| } | |||
| Operand* pad_out = op->outputs[0]; | |||
| for (auto& x : pad_out->consumers) | |||
| { | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j] == pad_out) | |||
| x->inputs[j] = op->inputs[0]; | |||
| } | |||
| op->inputs[0]->consumers.push_back(x); | |||
| } | |||
| pad_out->producer = 0; | |||
| pad_out->consumers.clear(); | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), pad_out)); | |||
| delete pad_out; | |||
| op->inputs.clear(); | |||
| op->outputs.clear(); | |||
| graph.ops.erase(graph.ops.begin() + i); | |||
| delete op; | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void eliminate_noop_pad(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,111 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_select_to_unbind.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_select_to_unbind(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "Tensor.select") | |||
| continue; | |||
| Operand* op_in = op->inputs[0]; | |||
| int input_rank = op_in->shape.size(); | |||
| if (input_rank == 0) | |||
| continue; | |||
| int dim = op->params.at("dim").i; | |||
| const int select_dimsize = op_in->shape[dim]; | |||
| // select 0..n | |||
| std::vector<int> select_n(select_dimsize, 0); | |||
| std::vector<Operator*> select_n_ops(select_dimsize, 0); | |||
| for (auto x : op_in->consumers) | |||
| { | |||
| if (x->type != "Tensor.select") | |||
| continue; | |||
| if (x->inputs[0] != op_in) | |||
| continue; | |||
| int dim2 = x->params.at("dim").i; | |||
| int index2 = x->params.at("index").i; | |||
| if (dim == dim2) | |||
| { | |||
| select_n[index2] = 1; | |||
| select_n_ops[index2] = x; | |||
| } | |||
| } | |||
| bool select_full_index = true; | |||
| for (auto x : select_n) | |||
| { | |||
| if (x == 0) | |||
| { | |||
| select_full_index = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!select_full_index) | |||
| continue; | |||
| matched = true; | |||
| // delete all select ops and replace with unbind | |||
| Operator* op_unbind = graph.new_operator_before("torch.unbind", op->name, op); | |||
| op_unbind->params["dim"] = dim; | |||
| op_unbind->inputs.push_back(op_in); | |||
| for (int j = 0; j < select_dimsize; j++) | |||
| { | |||
| op_in->consumers.erase(std::find(op_in->consumers.begin(), op_in->consumers.end(), select_n_ops[j])); | |||
| } | |||
| op_in->consumers.push_back(op_unbind); | |||
| op_unbind->outputs.resize(select_dimsize); | |||
| for (int j = 0; j < select_dimsize; j++) | |||
| { | |||
| op_unbind->outputs[j] = select_n_ops[j]->outputs[0]; | |||
| select_n_ops[j]->outputs[0]->producer = op_unbind; | |||
| } | |||
| for (int j = 0; j < select_dimsize; j++) | |||
| { | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), select_n_ops[j])); | |||
| delete select_n_ops[j]; | |||
| } | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void fuse_select_to_unbind(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -36,6 +36,7 @@ | |||
| #include "pass_ncnn/fuse_deconvolution_activation.h" | |||
| #include "pass_ncnn/fuse_deconvolutiondepthwise_activation.h" | |||
| #include "pass_ncnn/fuse_innerproduct_activation.h" | |||
| #include "pass_ncnn/fuse_transpose_matmul.h" | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| #include "pass_level4/canonicalize.h" | |||
| @@ -92,6 +93,7 @@ void pass_ncnn(Graph& g) | |||
| ncnn::insert_split(g); | |||
| ncnn::eliminate_noop(g); | |||
| ncnn::fuse_transpose_matmul(g); | |||
| ncnn::fuse_convolution_activation(g); | |||
| ncnn::fuse_convolution1d_activation(g); | |||
| ncnn::fuse_convolutiondepthwise_activation(g); | |||
| @@ -48,6 +48,19 @@ void convert_attribute(Graph& graph) | |||
| new_shape.push_back(data.shape[i]); | |||
| } | |||
| if (new_shape.size() == 5 && batch_index == 233) | |||
| { | |||
| if (new_shape[0] == 1) | |||
| { | |||
| fprintf(stderr, "assume pnnx attribute 5-rank tensor has batch_index 0\n"); | |||
| new_shape.erase(new_shape.begin()); | |||
| } | |||
| else | |||
| { | |||
| fprintf(stderr, "pnnx attribute 5-rank tensor is not supported yet!\n"); | |||
| } | |||
| } | |||
| if (new_shape.size() == 1) | |||
| { | |||
| op->params["0"] = new_shape[0]; | |||
| @@ -0,0 +1,66 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_transpose_matmul.h" | |||
| #include "pass_level2.h" | |||
| #include <float.h> | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class fuse_transpose_matmul_pass : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_a 0 1 a | |||
| pnnx.Input input_b 0 1 b | |||
| Permute op_0 1 1 b bt 0=1 | |||
| MatMul op_1 2 1 a bt out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "MatMul"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "matmultransb"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& /*captured_attrs*/) const | |||
| { | |||
| op->params["0"] = 1; | |||
| } | |||
| }; | |||
| void fuse_transpose_matmul(Graph& graph) | |||
| { | |||
| fuse_transpose_matmul_pass a; | |||
| int opindex = 0; | |||
| pnnx_graph_rewrite(graph, &a, opindex); | |||
| } | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,25 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void fuse_transpose_matmul(Graph& graph); | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -160,6 +160,28 @@ static void solve_batch_index_forward(Operand* operand) | |||
| solve_batch_index_backward(r); | |||
| } | |||
| } | |||
| else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") | |||
| { | |||
| const std::vector<int>& shape = op->params.at("shape").ai; | |||
| if (shape[batch_index] == 1) | |||
| { | |||
| for (Operand* r : op->outputs) | |||
| { | |||
| if (r->params.find("__batch_index") != r->params.end()) | |||
| continue; | |||
| r->params["__batch_index"] = batch_index; | |||
| solve_batch_index_forward(r); | |||
| solve_batch_index_backward(r); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // give up reshape across batch index | |||
| } | |||
| } | |||
| else | |||
| { | |||
| for (Operand* r : op->outputs) | |||
| @@ -207,6 +229,28 @@ static void solve_batch_index_backward(Operand* operand) | |||
| solve_batch_index_forward(r); | |||
| } | |||
| } | |||
| else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") | |||
| { | |||
| const std::vector<int>& shape = op->params.at("shape").ai; | |||
| if (shape[batch_index] == 1) | |||
| { | |||
| for (Operand* r : op->inputs) | |||
| { | |||
| if (r->params.find("__batch_index") != r->params.end()) | |||
| continue; | |||
| r->params["__batch_index"] = batch_index; | |||
| solve_batch_index_backward(r); | |||
| solve_batch_index_forward(r); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // give up reshape across batch index | |||
| } | |||
| } | |||
| else | |||
| { | |||
| for (Operand* r : op->inputs) | |||
| @@ -0,0 +1,54 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_matmul : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Input other 0 1 other | |||
| torch.matmul op_0 2 1 input other out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "MatMul"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "matmul"; | |||
| } | |||
| void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_matmul, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -49,14 +49,24 @@ pnnx.Output output 1 0 out | |||
| int dim0 = captured_params.at("dim0").i; | |||
| int dim1 = captured_params.at("dim1").i; | |||
| int input_rank = op->inputs[0]->shape.size(); | |||
| if (dim0 < 0) | |||
| { | |||
| dim0 = input_rank + dim0; | |||
| } | |||
| if (dim1 < 0) | |||
| { | |||
| dim1 = input_rank + dim1; | |||
| } | |||
| if (dim0 == batch_index || dim1 == batch_index) | |||
| { | |||
| fprintf(stderr, "permute across batch dim is not supported yet!\n"); | |||
| return; | |||
| } | |||
| int input_rank = op->inputs[0]->shape.size(); | |||
| if (batch_index >= 0 && batch_index < input_rank) | |||
| input_rank -= 1; | |||
| @@ -172,6 +172,7 @@ pnnx_add_test(torch_clamp) | |||
| pnnx_add_test(torch_clone) | |||
| pnnx_add_test(torch_flatten) | |||
| pnnx_add_test(torch_logsumexp) | |||
| pnnx_add_test(torch_matmul) | |||
| pnnx_add_test(torch_mean) | |||
| pnnx_add_test(torch_norm) | |||
| pnnx_add_test(torch_permute) | |||
| @@ -200,6 +201,7 @@ pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_linear_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_select_to_unbind) | |||
| if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") | |||
| pnnx_add_test(F_mish) | |||
| @@ -130,6 +130,7 @@ pnnx_ncnn_add_test(torch_chunk) | |||
| pnnx_ncnn_add_test(torch_clamp) | |||
| pnnx_ncnn_add_test(torch_clone) | |||
| pnnx_ncnn_add_test(torch_logsumexp) | |||
| pnnx_ncnn_add_test(torch_matmul) | |||
| pnnx_ncnn_add_test(torch_mean) | |||
| pnnx_ncnn_add_test(torch_permute) | |||
| pnnx_ncnn_add_test(torch_prod) | |||
| @@ -144,6 +145,8 @@ pnnx_ncnn_add_test(resnet18) | |||
| pnnx_ncnn_add_test(shufflenet_v2_x1_0) | |||
| pnnx_ncnn_add_test(squeezenet1_1) | |||
| pnnx_ncnn_add_test(ncnn_fuse_transpose_matmul) | |||
| if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") | |||
| pnnx_ncnn_add_test(F_mish) | |||
| pnnx_ncnn_add_test(nn_Mish) | |||
| @@ -0,0 +1,95 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1): | |||
| a = torch.matmul(a0, a1.transpose(-2, -1)) | |||
| b = torch.matmul(b0, b1.transpose(-2, -1)) | |||
| c = torch.matmul(c0, c1.transpose(-2, -1)) | |||
| d = torch.matmul(d0, d1.transpose(-2, -1)) | |||
| e = torch.matmul(e0, e1.transpose(-2, -1)) | |||
| f = torch.matmul(f0, f1.transpose(-2, -1)) | |||
| g = torch.matmul(g0, g1.transpose(-2, -1)) | |||
| h = torch.matmul(h0, h1.transpose(-2, -1)) | |||
| i = torch.matmul(i0, i1.transpose(-2, -1)) | |||
| j = torch.matmul(j0, j1.transpose(-2, -1)) | |||
| k = torch.matmul(k0, k1.transpose(-2, -1)) | |||
| l = torch.matmul(l0, l1.transpose(-2, -1)) | |||
| return a, b, c, d, e, f, g, h, i, j, k, l | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| a0 = torch.rand(14) | |||
| a1 = torch.rand(6, 14) | |||
| b0 = torch.rand(13) | |||
| b1 = torch.rand(7, 4, 13) | |||
| c0 = torch.rand(15) | |||
| c1 = torch.rand(5, 7, 9, 15) | |||
| d0 = torch.rand(23, 14) | |||
| d1 = torch.rand(25, 14) | |||
| e0 = torch.rand(4, 5) | |||
| e1 = torch.rand(10, 40, 5) | |||
| f0 = torch.rand(14, 6) | |||
| f1 = torch.rand(2, 4, 20, 6) | |||
| g0 = torch.rand(10, 23, 14) | |||
| g1 = torch.rand(5, 14) | |||
| h0 = torch.rand(7, 8, 13, 14) | |||
| h1 = torch.rand(35, 14) | |||
| i0 = torch.rand(10, 23, 14) | |||
| i1 = torch.rand(10, 5, 14) | |||
| j0 = torch.rand(10, 13, 18) | |||
| j1 = torch.rand(3, 1, 8, 18) | |||
| k0 = torch.rand(1, 5, 23, 11) | |||
| k1 = torch.rand(8, 1, 9, 11) | |||
| l0 = torch.rand(6, 9, 13, 14) | |||
| l1 = torch.rand(6, 9, 15, 14) | |||
| a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1)) | |||
| mod.save("test_ncnn_fuse_transpose_matmul.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_ncnn_fuse_transpose_matmul.pt inputshape=[14],[6,14],[13],[7,4,13],[15],[5,7,9,15],[23,14],[25,14],[4,5],[10,40,5],[14,6],[2,4,20,6],[10,23,14],[5,14],[7,8,13,14],[35,14],[10,23,14],[10,5,14],[10,13,18],[3,1,8,18],[1,5,23,11],[8,1,9,11],[6,9,13,14],[6,9,15,14]") | |||
| # ncnn inference | |||
| import test_ncnn_fuse_transpose_matmul_ncnn | |||
| b = test_ncnn_fuse_transpose_matmul_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| print(a0.shape) | |||
| print(b0.shape) | |||
| print(a0) | |||
| print(b0) | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,107 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1): | |||
| a = torch.matmul(a0, a1) | |||
| b = torch.matmul(b0, b1) | |||
| c = torch.matmul(c0, c1) | |||
| d = torch.matmul(d0, d1) | |||
| e = torch.matmul(e0, e1) | |||
| f = torch.matmul(f0, f1) | |||
| g = torch.matmul(g0, g1) | |||
| h = torch.matmul(h0, h1) | |||
| i = torch.matmul(i0, i1) | |||
| j = torch.matmul(j0, j1) | |||
| k = torch.matmul(k0, k1) | |||
| l = torch.matmul(l0, l1) | |||
| m = torch.matmul(m0, m1) | |||
| n = torch.matmul(n0, n1) | |||
| o = torch.matmul(o0, o1) | |||
| p = torch.matmul(p0, p1) | |||
| return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| a0 = torch.rand(13) | |||
| a1 = torch.rand(13) | |||
| b0 = torch.rand(14) | |||
| b1 = torch.rand(14, 6) | |||
| c0 = torch.rand(13) | |||
| c1 = torch.rand(7, 13, 4) | |||
| d0 = torch.rand(15) | |||
| d1 = torch.rand(5, 7, 15, 9) | |||
| e0 = torch.rand(5, 12) | |||
| e1 = torch.rand(12) | |||
| f0 = torch.rand(10, 3, 4) | |||
| f1 = torch.rand(4) | |||
| g0 = torch.rand(6, 3, 7, 14) | |||
| g1 = torch.rand(14) | |||
| h0 = torch.rand(23, 14) | |||
| h1 = torch.rand(14, 25) | |||
| i0 = torch.rand(4, 5) | |||
| i1 = torch.rand(10, 5, 40) | |||
| j0 = torch.rand(14, 6) | |||
| j1 = torch.rand(2, 4, 6, 20) | |||
| k0 = torch.rand(10, 23, 14) | |||
| k1 = torch.rand(14, 5) | |||
| l0 = torch.rand(7, 8, 13, 14) | |||
| l1 = torch.rand(14, 35) | |||
| m0 = torch.rand(10, 23, 14) | |||
| m1 = torch.rand(10, 14, 5) | |||
| n0 = torch.rand(10, 13, 18) | |||
| n1 = torch.rand(3, 1, 18, 8) | |||
| o0 = torch.rand(1, 5, 23, 11) | |||
| o1 = torch.rand(8, 1, 11, 9) | |||
| p0 = torch.rand(6, 9, 13, 14) | |||
| p1 = torch.rand(6, 9, 14, 15) | |||
| a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1)) | |||
| mod.save("test_torch_matmul.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_matmul.pt inputshape=[13],[13],[14],[14,6],[13],[7,13,4],[15],[5,7,15,9],[5,12],[12],[10,3,4],[4],[6,3,7,14],[14],[23,14],[14,25],[4,5],[10,5,40],[14,6],[2,4,6,20],[10,23,14],[14,5],[7,8,13,14],[14,35],[10,23,14],[10,14,5],[10,13,18],[3,1,18,8],[1,5,23,11],[8,1,11,9],[6,9,13,14],[6,9,14,15]") | |||
| # ncnn inference | |||
| import test_torch_matmul_ncnn | |||
| b = test_torch_matmul_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| print(a0.shape) | |||
| print(b0.shape) | |||
| print(a0) | |||
| print(b0) | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,69 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x): | |||
| x0 = torch.select(x, 0, 0) | |||
| x1 = torch.select(x, 0, 1) | |||
| x2 = torch.select(x, 0, 2) | |||
| y0 = torch.select(x, 1, 0) | |||
| y1 = torch.select(x, 1, 1) | |||
| y2 = torch.select(x, 1, 2) | |||
| y3 = torch.select(x, 1, 3) | |||
| z0 = torch.select(x, 2, 0) | |||
| z1 = torch.select(x, 2, 1) | |||
| z2 = torch.select(x, 2, 2) | |||
| z3 = torch.select(x, 2, 3) | |||
| z4 = torch.select(x, 2, 4) | |||
| return x0, x1, x2, y0, y1, y2, y3, z0, z1, z2, z3, z4 | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 4, 5) | |||
| a = net(x) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, x) | |||
| mod.save("test_pnnx_fuse_select_to_unbind.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_pnnx_fuse_select_to_unbind.pt inputshape=[3,4,5]") | |||
| # pnnx inference | |||
| import test_pnnx_fuse_select_to_unbind_pnnx | |||
| b = test_pnnx_fuse_select_to_unbind_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,103 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1): | |||
| a = torch.matmul(a0, a1) | |||
| b = torch.matmul(b0, b1) | |||
| c = torch.matmul(c0, c1) | |||
| d = torch.matmul(d0, d1) | |||
| e = torch.matmul(e0, e1) | |||
| f = torch.matmul(f0, f1) | |||
| g = torch.matmul(g0, g1) | |||
| h = torch.matmul(h0, h1) | |||
| i = torch.matmul(i0, i1) | |||
| j = torch.matmul(j0, j1) | |||
| k = torch.matmul(k0, k1) | |||
| l = torch.matmul(l0, l1) | |||
| m = torch.matmul(m0, m1) | |||
| n = torch.matmul(n0, n1) | |||
| o = torch.matmul(o0, o1) | |||
| p = torch.matmul(p0, p1) | |||
| return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| a0 = torch.rand(13) | |||
| a1 = torch.rand(13) | |||
| b0 = torch.rand(14) | |||
| b1 = torch.rand(14, 6) | |||
| c0 = torch.rand(13) | |||
| c1 = torch.rand(7, 13, 4) | |||
| d0 = torch.rand(15) | |||
| d1 = torch.rand(5, 7, 15, 9) | |||
| e0 = torch.rand(5, 12) | |||
| e1 = torch.rand(12) | |||
| f0 = torch.rand(10, 3, 4) | |||
| f1 = torch.rand(4) | |||
| g0 = torch.rand(6, 3, 7, 14) | |||
| g1 = torch.rand(14) | |||
| h0 = torch.rand(23, 14) | |||
| h1 = torch.rand(14, 25) | |||
| i0 = torch.rand(4, 5) | |||
| i1 = torch.rand(10, 5, 40) | |||
| j0 = torch.rand(14, 6) | |||
| j1 = torch.rand(2, 4, 6, 20) | |||
| k0 = torch.rand(10, 23, 14) | |||
| k1 = torch.rand(14, 5) | |||
| l0 = torch.rand(7, 8, 13, 14) | |||
| l1 = torch.rand(14, 35) | |||
| m0 = torch.rand(10, 23, 14) | |||
| m1 = torch.rand(10, 14, 5) | |||
| n0 = torch.rand(10, 13, 18) | |||
| n1 = torch.rand(3, 1, 18, 8) | |||
| o0 = torch.rand(1, 5, 23, 11) | |||
| o1 = torch.rand(8, 1, 11, 9) | |||
| p0 = torch.rand(6, 9, 13, 14) | |||
| p1 = torch.rand(6, 9, 14, 15) | |||
| a = net(a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (a0, a1, b0, b1, c0, c1, d0, d1, e0, e1, f0, f1, g0, g1, h0, h1, i0, i1, j0, j1, k0, k1, l0, l1, m0, m1, n0, n1, o0, o1, p0, p1)) | |||
| mod.save("test_torch_matmul.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_matmul.pt inputshape=[13],[13],[14],[14,6],[13],[7,13,4],[15],[5,7,15,9],[5,12],[12],[10,3,4],[4],[6,3,7,14],[14],[23,14],[14,25],[4,5],[10,5,40],[14,6],[2,4,6,20],[10,23,14],[14,5],[7,8,13,14],[14,35],[10,23,14],[10,14,5],[10,13,18],[3,1,18,8],[1,5,23,11],[8,1,11,9],[6,9,13,14],[6,9,14,15]") | |||
| # pnnx inference | |||
| import test_torch_matmul_pnnx | |||
| b = test_torch_matmul_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||