From 12eaa6f9ba6c51cd79b6ab0e003e6efc85553ede Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 22 Jun 2021 15:29:07 +0800 Subject: [PATCH] update concat test --- tests/test_concat.cpp | 200 ++++++++++++++++++++++++++++++++---------- 1 file changed, 154 insertions(+), 46 deletions(-) diff --git a/tests/test_concat.cpp b/tests/test_concat.cpp index 166e974a3..dace600af 100644 --- a/tests/test_concat.cpp +++ b/tests/test_concat.cpp @@ -34,9 +34,14 @@ static int test_concat(const std::vector& a, int axis) static int test_concat_0() { std::vector a(3); - a[0] = RandomMat(16, 12, 64); - a[1] = RandomMat(16, 12, 64); - a[2] = RandomMat(16, 12, 64); + a[0] = RandomMat(16, 12, 24); + a[1] = RandomMat(16, 12, 24); + a[2] = RandomMat(16, 12, 24); + + std::vector b(3); + b[0] = RandomMat(16, 12, 64); + b[1] = RandomMat(16, 12, 64); + b[2] = RandomMat(16, 12, 64); return 0 || test_concat(a, 0) @@ -44,127 +49,230 @@ static int test_concat_0() || test_concat(a, 2) || test_concat(a, -1) || test_concat(a, -2) - || test_concat(a, -3); + || test_concat(a, -3) + + || test_concat(b, 0) + || test_concat(b, 1) + || test_concat(b, 2) + || test_concat(b, -1) + || test_concat(b, -2) + || test_concat(b, -3); } static int test_concat_1() { std::vector a(3); - a[0] = RandomMat(7, 3, 6); - a[1] = RandomMat(7, 3, 16); - a[2] = RandomMat(7, 3, 10); + a[0] = RandomMat(7, 3, 3); + a[1] = RandomMat(7, 3, 8); + a[2] = RandomMat(7, 3, 5); std::vector b(3); - b[0] = RandomMat(9, 5, 16); - b[1] = RandomMat(9, 5, 8); - b[2] = RandomMat(9, 5, 24); + b[0] = RandomMat(9, 5, 8); + b[1] = RandomMat(9, 5, 4); + b[2] = RandomMat(9, 5, 12); + + std::vector c(3); + c[0] = RandomMat(7, 3, 6); + c[1] = RandomMat(7, 3, 16); + c[2] = RandomMat(7, 3, 10); + + std::vector d(3); + d[0] = RandomMat(9, 5, 16); + d[1] = RandomMat(9, 5, 8); + d[2] = RandomMat(9, 5, 24); return 0 || test_concat(a, 0) || test_concat(a, -3) || test_concat(b, 0) - || test_concat(b, -3); + || test_concat(b, -3) + + || test_concat(c, 0) + || test_concat(c, -3) + + || test_concat(d, 0) + || test_concat(d, -3); } static int test_concat_2() { std::vector a(3); - a[0] = RandomMat(7, 3, 10); - a[1] = RandomMat(7, 8, 10); - a[2] = RandomMat(7, 5, 10); + a[0] = RandomMat(7, 3, 5); + a[1] = RandomMat(7, 8, 5); + a[2] = RandomMat(7, 5, 5); std::vector b(3); - b[0] = RandomMat(9, 8, 24); - b[1] = RandomMat(9, 3, 24); - b[2] = RandomMat(9, 5, 24); + b[0] = RandomMat(9, 8, 12); + b[1] = RandomMat(9, 3, 12); + b[2] = RandomMat(9, 5, 12); + + std::vector c(3); + c[0] = RandomMat(7, 3, 10); + c[1] = RandomMat(7, 8, 10); + c[2] = RandomMat(7, 5, 10); + + std::vector d(3); + d[0] = RandomMat(9, 8, 24); + d[1] = RandomMat(9, 3, 24); + d[2] = RandomMat(9, 5, 24); return 0 || test_concat(a, 1) || test_concat(a, -2) || test_concat(b, 1) - || test_concat(b, -2); + || test_concat(b, -2) + + || test_concat(c, 1) + || test_concat(c, -2) + + || test_concat(d, 1) + || test_concat(d, -2); } static int test_concat_3() { std::vector a(3); - a[0] = RandomMat(8, 9, 6); - a[1] = RandomMat(3, 9, 6); - a[2] = RandomMat(5, 9, 6); + a[0] = RandomMat(8, 9, 3); + a[1] = RandomMat(3, 9, 3); + a[2] = RandomMat(5, 9, 3); std::vector b(3); - b[0] = RandomMat(1, 7, 32); - b[1] = RandomMat(8, 7, 32); - b[2] = RandomMat(7, 7, 32); + b[0] = RandomMat(1, 7, 16); + b[1] = RandomMat(8, 7, 16); + b[2] = RandomMat(7, 7, 16); + + std::vector c(3); + c[0] = RandomMat(8, 9, 6); + c[1] = RandomMat(3, 9, 6); + c[2] = RandomMat(5, 9, 6); + + std::vector d(3); + d[0] = RandomMat(1, 7, 32); + d[1] = RandomMat(8, 7, 32); + d[2] = RandomMat(7, 7, 32); return 0 || test_concat(a, 2) || test_concat(a, -1) || test_concat(b, 2) - || test_concat(b, -1); + || test_concat(b, -1) + + || test_concat(c, 2) + || test_concat(c, -1) + + || test_concat(d, 2) + || test_concat(d, -1); } static int test_concat_4() { std::vector a(3); - a[0] = RandomMat(11, 6); - a[1] = RandomMat(11, 16); - a[2] = RandomMat(11, 10); + a[0] = RandomMat(11, 3); + a[1] = RandomMat(11, 8); + a[2] = RandomMat(11, 5); std::vector b(3); - b[0] = RandomMat(15, 24); - b[1] = RandomMat(15, 16); - b[2] = RandomMat(15, 8); + b[0] = RandomMat(15, 12); + b[1] = RandomMat(15, 8); + b[2] = RandomMat(15, 4); + + std::vector c(3); + c[0] = RandomMat(11, 6); + c[1] = RandomMat(11, 16); + c[2] = RandomMat(11, 10); + + std::vector d(3); + d[0] = RandomMat(15, 24); + d[1] = RandomMat(15, 16); + d[2] = RandomMat(15, 8); return 0 || test_concat(a, 0) || test_concat(a, -2) || test_concat(b, 0) - || test_concat(b, -2); + || test_concat(b, -2) + + || test_concat(c, 0) + || test_concat(c, -2) + + || test_concat(d, 0) + || test_concat(d, -2); } static int test_concat_5() { std::vector a(3); - a[0] = RandomMat(9, 14); - a[1] = RandomMat(8, 14); - a[2] = RandomMat(11, 14); + a[0] = RandomMat(9, 7); + a[1] = RandomMat(8, 7); + a[2] = RandomMat(11, 7); std::vector b(3); - b[0] = RandomMat(13, 48); - b[1] = RandomMat(18, 48); - b[2] = RandomMat(15, 48); + b[0] = RandomMat(13, 24); + b[1] = RandomMat(18, 24); + b[2] = RandomMat(15, 24); + + std::vector c(3); + c[0] = RandomMat(9, 14); + c[1] = RandomMat(8, 14); + c[2] = RandomMat(11, 14); + + std::vector d(3); + d[0] = RandomMat(13, 48); + d[1] = RandomMat(18, 48); + d[2] = RandomMat(15, 48); return 0 || test_concat(a, 1) || test_concat(a, -1) || test_concat(b, 1) - || test_concat(b, -1); + || test_concat(b, -1) + + || test_concat(c, 1) + || test_concat(c, -1) + + || test_concat(d, 1) + || test_concat(d, -1); } static int test_concat_6() { std::vector a(3); - a[0] = RandomMat(6); - a[1] = RandomMat(16); - a[2] = RandomMat(10); + a[0] = RandomMat(3); + a[1] = RandomMat(8); + a[2] = RandomMat(5); std::vector b(3); - b[0] = RandomMat(8); - b[1] = RandomMat(16); - b[2] = RandomMat(24); + b[0] = RandomMat(4); + b[1] = RandomMat(8); + b[2] = RandomMat(12); + + std::vector c(3); + c[0] = RandomMat(6); + c[1] = RandomMat(16); + c[2] = RandomMat(10); + + std::vector d(3); + d[0] = RandomMat(8); + d[1] = RandomMat(16); + d[2] = RandomMat(24); return 0 || test_concat(a, 0) || test_concat(a, -1) || test_concat(b, 0) - || test_concat(b, -1); + || test_concat(b, -1) + + || test_concat(c, 0) + || test_concat(c, -1) + + || test_concat(d, 0) + || test_concat(d, -1); } int main()