Browse Source

fix gemm transpose B wrong result when tile N is not a multiple of 4, optimize load C (#4430)

tags/20230223
nihui GitHub 3 years ago
parent
commit
88dba58992
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 450 additions and 569 deletions
  1. +15
    -9
      src/layer/arm/gemm_arm.cpp
  2. +435
    -560
      src/layer/x86/gemm_x86.cpp

+ 15
- 9
src/layer/arm/gemm_arm.cpp View File

@@ -707,7 +707,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
{
if (elempack == 4)
{
const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4;
const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4;

int kk = 0;
for (; kk + 3 < max_kk; kk += 4)
@@ -750,7 +750,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
{
if (elempack == 4)
{
const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4;
const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4;

int kk = 0;
for (; kk + 3 < max_kk; kk += 4)
@@ -787,7 +787,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
{
if (elempack == 4)
{
const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4;
const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4;

int kk = 0;
for (; kk + 3 < max_kk; kk += 4)
@@ -820,7 +820,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
#if __ARM_NEON
if (elempack == 4)
{
const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4;
const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4;

int kk = 0;
for (; kk + 3 < max_kk; kk += 4)
@@ -853,7 +853,7 @@ static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int
#if __ARM_NEON
if (elempack == 4)
{
const float* p0 = (const float*)B + k / 4 * 4 * B_hstep + (j + jj) * 4;
const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4;

int kk = 0;
for (; kk + 3 < max_kk; kk += 4)
@@ -3789,11 +3789,11 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N,

#if __aarch64__
TILE_M = tile_size / 8 * 8;
TILE_N = tile_size;
TILE_N = tile_size / 4 * 4;
TILE_K = tile_size / 8 * 8;
#elif __ARM_NEON
TILE_M = tile_size / 4 * 4;
TILE_N = tile_size;
TILE_N = tile_size / 4 * 4;
TILE_K = tile_size / 4 * 4;
#else
TILE_M = tile_size / 2 * 2;
@@ -3818,10 +3818,10 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N,

#if __aarch64__
TILE_M = tile_size / 8 * 8;
TILE_N = tile_size;
TILE_N = tile_size / 4 * 4;
#elif __ARM_NEON
TILE_M = tile_size / 4 * 4;
TILE_N = tile_size;
TILE_N = tile_size / 4 * 4;
#else
TILE_M = tile_size / 2 * 2;
TILE_N = tile_size;
@@ -3846,7 +3846,13 @@ static void get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N,
if (N > 0)
{
int nn_N = (N + TILE_N - 1) / TILE_N;
#if __aarch64__
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4);
#elif __ARM_NEON
TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4);
#else
TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N);
#endif
}

if (nT > 1)


+ 435
- 560
src/layer/x86/gemm_x86.cpp
File diff suppressed because it is too large
View File


Loading…
Cancel
Save