|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
#include "mindspore/core/utils/log_adapter.h" |
|
|
|
#include "common/common_test.h" |
|
|
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h" |
|
|
|
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "src/lite_kernel.h" |
|
|
|
|
|
|
|
@@ -26,6 +27,144 @@ class TestMatMulFp32 : public mindspore::Common { |
|
|
|
TestMatMulFp32() {} |
|
|
|
}; |
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row2Col8Test1) { |
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, |
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, |
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, |
|
|
|
0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, |
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, |
|
|
|
0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; |
|
|
|
float co[] = {0.21, 0.67, 0.53, 0.09, 0.43, 0.35, 0.04, 0.43, 0.38, 0.81, 0.97, 0.50, 0.14, 0.52, 0.10, 0.14, |
|
|
|
0.81, 0.57, 0.92, 0.39, 0.67, 0.02, 0.18, 0.67, 0.98, 0.70, 0.35, 0.09, 0.10, 0.33, 0.92, 0.10, |
|
|
|
0.09, 0.27, 0.74, 0.93, 0.73, 0.99, 0.46, 0.73, 0.68, 0.90, 0.78, 0.91, 0.37, 0.49, 0.08, 0.37, |
|
|
|
0.02, 0.07, 0.87, 0.20, 0.24, 0.67, 0.04, 0.24, 0.33, 0.13, 0.23, 0.97, 0.93, 0.75, 0.24, 0.93, |
|
|
|
0.85, 0.03, 0.34, 0.61, 0.31, 0.66, 0.52, 0.31, 0.35, 0.04, 0, 0, 0, 0, 0, 0, |
|
|
|
0.52, 0.10, 0, 0, 0, 0, 0, 0, 0.02, 0.18, 0, 0, 0, 0, 0, 0, |
|
|
|
0.33, 0.92, 0, 0, 0, 0, 0, 0, 0.99, 0.46, 0, 0, 0, 0, 0, 0, |
|
|
|
0.49, 0.08, 0, 0, 0, 0, 0, 0, 0.67, 0.04, 0, 0, 0, 0, 0, 0, |
|
|
|
0.75, 0.24, 0, 0, 0, 0, 0, 0, 0.66, 0.52, 0, 0, 0, 0, 0, 0}; |
|
|
|
float out[144] = {0}; |
|
|
|
RowMajor2Col8Major(in, out, 10, 9); |
|
|
|
CompareOutputData(out, co, 144, 0.0001); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row2Col8Test2) { |
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, |
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, |
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, |
|
|
|
0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, |
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, |
|
|
|
0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; |
|
|
|
float co[] = {0.21, 0.68, 0.81, 0.07, 0.92, 0.23, 0.09, 0.61, 0.38, 0.02, 0.57, 0.13, 0.35, 0.34, 0.93, |
|
|
|
0.43, 0.81, 0.33, 0.70, 0.03, 0.74, 0.09, 0.91, 0.14, 0.98, 0.85, 0.27, 0.53, 0.78, 0.50, |
|
|
|
0.20, 0.67, 0.09, 0.67, 0.90, 0.97, 0.87, 0.39, 0.97, 0.10, 0.73, 0.35, 0.49, 0.10, 0.04, |
|
|
|
0.67, 0.93, 0.33, 0.37, 0.52, 0.67, 0.18, 0.24, 0.10, 0.31, 0.99, 0.24, 0.02, 0.75, 0.92, |
|
|
|
0.52, 0.73, 0.35, 0.49, 0.93, 0.33, 0.66, 0.46, 0.43, 0.37, 0.52, 0.67, 0.31, 0.99, 0.04, |
|
|
|
0.08, 0.14, 0.24, 0.02, 0.75, 0.66, 0.46, 0, 0, 0, 0, 0, 0, 0.04, 0.08, |
|
|
|
0, 0, 0, 0, 0, 0, 0.10, 0.04, 0, 0, 0, 0, 0, 0, 0.18, |
|
|
|
0.24, 0, 0, 0, 0, 0, 0, 0.92, 0.52, 0, 0, 0, 0, 0, 0}; |
|
|
|
float out[120] = {0}; |
|
|
|
RowMajor2Col8Major(in, out, 18, 5); |
|
|
|
CompareOutputData(out, co, 120, 0.0001); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest1) { |
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0, |
|
|
|
0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0, |
|
|
|
0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0, |
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0, 0, 0, 0.61, 0.43, 0.14, 0.67, 0.10, 0, 0, 0, |
|
|
|
0.73, 0.37, 0.24, 0.93, 0.31, 0, 0, 0, 0.35, 0.52, 0.02, 0.33, 0.99, 0, 0, 0, |
|
|
|
0.49, 0.67, 0.75, 0.66, 0.04, 0, 0, 0, 0.10, 0.18, 0.92, 0.46, 0.08, 0, 0, 0, |
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0, 0, 0, 0.67, 0.10, 0.73, 0.37, 0.24, 0, 0, 0, |
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0, 0, 0, 0.33, 0.99, 0.49, 0.67, 0.75, 0, 0, 0, |
|
|
|
0.66, 0.04, 0.10, 0.18, 0.92, 0, 0, 0, 0.46, 0.08, 0.04, 0.24, 0.52, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
|
|
|
float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, |
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, |
|
|
|
0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, |
|
|
|
0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, |
|
|
|
0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, |
|
|
|
0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52}; |
|
|
|
float out[90] = {0}; |
|
|
|
Row8x8Major2RowMajor(in, out, 18, 5); |
|
|
|
CompareOutputData(out, co, 90, 0.0001); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest2) { |
|
|
|
float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0, |
|
|
|
0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0, |
|
|
|
0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
|
|
|
float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, |
|
|
|
0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39}; |
|
|
|
float out[30] = {0}; |
|
|
|
Row8x8Major2RowMajor(in, out, 6, 5); |
|
|
|
CompareOutputData(out, co, 30, 0.0001); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestMatMulFp32, Row8x82RowTest3) { |
|
|
|
float in[] = { |
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, |
|
|
|
0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, |
|
|
|
0.10, 0.18, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.93, |
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.75, 0.66, 0.04, 0.10, |
|
|
|
0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, |
|
|
|
0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, |
|
|
|
0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, |
|
|
|
0.39, 0.09, 0.93, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, |
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, |
|
|
|
0.98, 0.09, 0.68, 0.02, 0.33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
|
|
|
0, 0, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, |
|
|
|
0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, |
|
|
|
0.24, 0.52, 0.21, 0.38, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, |
|
|
|
0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.04, 0.24, |
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, 0.67, 0.81, 0.57, 0.70, |
|
|
|
0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, |
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, |
|
|
|
0.61, 0.43, 0.14, 0.67, 0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, |
|
|
|
0.99, 0.49, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, |
|
|
|
0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, |
|
|
|
0, 0.04, 0.10, 0.18, 0, 0, 0, 0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.81, 0.98, |
|
|
|
0.09, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, 0.37, 0.24, 0, 0, |
|
|
|
0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, 0, 0, 0, 0, |
|
|
|
0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, 0, 0.13, 0.03, 0.53, |
|
|
|
0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, 0, 0.04, 0.10, 0.18, 0, 0, 0, |
|
|
|
0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, |
|
|
|
0.37, 0.24, 0, 0, 0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, |
|
|
|
0, 0, 0, 0, 0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, |
|
|
|
0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
|
|
|
0, 0, 0, 0, 0, 0}; |
|
|
|
float co[] = { |
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, |
|
|
|
0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, |
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, |
|
|
|
0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, |
|
|
|
0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.21, 0.38, 0.81, 0.98, 0.09, |
|
|
|
0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, |
|
|
|
0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, |
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, |
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67, |
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, |
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, |
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, |
|
|
|
0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, |
|
|
|
0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, |
|
|
|
0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, |
|
|
|
0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, |
|
|
|
0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, |
|
|
|
0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, |
|
|
|
0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67, |
|
|
|
0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, |
|
|
|
0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, |
|
|
|
0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53}; |
|
|
|
float out[418] = {0}; |
|
|
|
Row8x8Major2RowMajor(in, out, 22, 19); |
|
|
|
CompareOutputData(out, co, 418, 0.0001); |
|
|
|
} |
|
|
|
|
|
|
|
int MMTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_, |
|
|
|
float *a_ptr, float *b_ptr, std::vector<int> a_shape, std::vector<int> b_shape, |
|
|
|
std::vector<int> c_shape) { |
|
|
|
|