|
|
|
@@ -49,7 +49,7 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const |
|
|
|
__m256 in4 = _mm256_loadu_ps(in[k + 3]); |
|
|
|
__m256 w4 = _mm256_loadu_ps(w + 24); |
|
|
|
out1 = _mm256_fmadd_ps(in3, w3, out1); |
|
|
|
__m256 in5 = _mm256_loadu_ps(in[k + 8]); |
|
|
|
__m256 in5 = _mm256_loadu_ps(in[k + 4]); |
|
|
|
__m256 w5 = _mm256_loadu_ps(w + 32); |
|
|
|
out1 = _mm256_fmadd_ps(in4, w4, out1); |
|
|
|
w += 40; |
|
|
|
@@ -68,7 +68,10 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const |
|
|
|
__m256 zero = _mm256_setzero_ps(); |
|
|
|
out1 = _mm256_max_ps(out1, zero); |
|
|
|
} |
|
|
|
if (c == C8NUM) { |
|
|
|
if (c > C8NUM || c8_mod == 0) { |
|
|
|
_mm256_storeu_ps(output, out1); |
|
|
|
output += C8NUM; |
|
|
|
} else { |
|
|
|
__m128 tmp; |
|
|
|
switch (c8_mod) { |
|
|
|
case 1: |
|
|
|
@@ -105,10 +108,7 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const |
|
|
|
_mm256_storeu_ps(output, out1); |
|
|
|
break; |
|
|
|
} |
|
|
|
output += c8_mod == 0 ? C8NUM : c8_mod; |
|
|
|
} else { |
|
|
|
_mm256_storeu_ps(output, out1); |
|
|
|
output += C8NUM; |
|
|
|
output += c8_mod; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|