|
|
|
@@ -35,7 +35,7 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data |
|
|
|
const float *src = src_b + c * param->inner_size_; |
|
|
|
float *dst = dst_b + c * param->inner_size_; |
|
|
|
double mean = 0.0f; |
|
|
|
double square_mean = 0.0f; |
|
|
|
double squ_m = 0.0f; |
|
|
|
|
|
|
|
int index = 0; |
|
|
|
#if defined(ENABLE_AVX) |
|
|
|
@@ -46,7 +46,7 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data |
|
|
|
__m128 square128 = _mm_add_ps(_mm256_extractf128_ps(squarev, 0), _mm256_extractf128_ps(squarev, 1)); |
|
|
|
for (int i = 0; i < C4NUM; ++i) { |
|
|
|
mean += MS_F32X4_GETI(src128, i); |
|
|
|
square_mean += MS_F32X4_GETI(square128, i); |
|
|
|
squ_m += MS_F32X4_GETI(square128, i); |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
@@ -57,11 +57,11 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv); |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
mean += vaddvq_f32(srcv); |
|
|
|
square_mean += vaddvq_f32(squarev); |
|
|
|
squ_m += vaddvq_f32(squarev); |
|
|
|
#elif defined(ENABLE_SSE) |
|
|
|
for (int i = 0; i < C4NUM; ++i) { |
|
|
|
mean += MS_F32X4_GETI(srcv, i); |
|
|
|
square_mean += MS_F32X4_GETI(squarev, i); |
|
|
|
squ_m += MS_F32X4_GETI(squarev, i); |
|
|
|
} |
|
|
|
#else |
|
|
|
float32x2_t src_add2 = vadd_f32(vget_low_f32(srcv), vget_high_f32(srcv)); |
|
|
|
@@ -69,18 +69,18 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data |
|
|
|
mean += vget_lane_f32(src_add4, 0); |
|
|
|
float32x2_t square_add2 = vadd_f32(vget_low_f32(squarev), vget_high_f32(squarev)); |
|
|
|
float32x2_t square_add4 = vpadd_f32(square_add2, square_add2); |
|
|
|
square_mean += vget_lane_f32(square_add4, 0); |
|
|
|
squ_m += vget_lane_f32(square_add4, 0); |
|
|
|
#endif |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; index < param->inner_size_; index++) { |
|
|
|
mean += src[index]; |
|
|
|
square_mean += src[index] * src[index]; |
|
|
|
squ_m += src[index] * src[index]; |
|
|
|
} |
|
|
|
|
|
|
|
mean /= (float)param->inner_size_; |
|
|
|
square_mean /= (float)param->inner_size_; |
|
|
|
const double deno = gamma_data[c] / sqrt(square_mean - mean * mean + param->epsilon_); |
|
|
|
squ_m /= (float)param->inner_size_; |
|
|
|
const double deno = gamma_data[c] / sqrt(squ_m - mean * mean + param->epsilon_); |
|
|
|
|
|
|
|
index = 0; |
|
|
|
#if defined(ENABLE_AVX) |
|
|
|
@@ -112,6 +112,112 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
#if defined(ENABLE_SSE) || defined(ENABLE_ARM) |
|
|
|
void InstanceNormC4HW4ArmSse(const float *src_b, float *dst_b, const float *gamma_data, const float *beta_data, |
|
|
|
int *c_src, const InstanceNormParameter *param, int channel, int channel_end, int hw_plane, |
|
|
|
MS_FLOAT32X4 hw_planev) { |
|
|
|
int c = *c_src; |
|
|
|
for (; c <= channel_end - C16NUM; c += C16NUM) { |
|
|
|
const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; |
|
|
|
const float *src2 = src_b + (c + C8NUM) * hw_plane, *src3 = src_b + (c + C12NUM) * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f), mean3 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 squ_m2 = MS_MOVQ_F32(0.0f), squ_m3 = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); |
|
|
|
MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2), squarev3 = MS_MULQ_F32(srcv3, srcv3); |
|
|
|
MS_ADDQ_F32_VEC(mean, mean1, mean2, mean3, srcv, srcv1, srcv2, srcv3); |
|
|
|
MS_ADDQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, squarev, squarev1, squarev2, squarev3); |
|
|
|
} |
|
|
|
MS_DIVQ_F32_VEC(mean, mean1, mean2, mean3, hw_planev); |
|
|
|
MS_DIVQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, hw_planev); |
|
|
|
|
|
|
|
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno2 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno3 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
|
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); |
|
|
|
deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2)); |
|
|
|
deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); |
|
|
|
MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM), betav3 = MS_LDQ_F32(beta_data + c + C12NUM); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); |
|
|
|
MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2), outv3 = MS_SUBQ_F32(srcv3, mean3); |
|
|
|
|
|
|
|
outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); |
|
|
|
outv2 = MS_MULQ_F32(outv2, gammav2), outv3 = MS_MULQ_F32(outv3, gammav3); |
|
|
|
MS_ADDQ_F32_VEC(outv, outv1, outv2, outv3, betav, betav1, betav2, betav3); |
|
|
|
|
|
|
|
MS_STQ_F32(dst + index * channel, outv), MS_STQ_F32(dst + index * channel + C4NUM, outv1); |
|
|
|
MS_STQ_F32(dst + index * channel + C8NUM, outv2), MS_STQ_F32(dst + index * channel + C12NUM, outv3); |
|
|
|
} |
|
|
|
} |
|
|
|
for (; c <= channel_end - C8NUM; c += C8NUM) { |
|
|
|
const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); |
|
|
|
mean = MS_ADDQ_F32(mean, srcv), mean1 = MS_ADDQ_F32(mean1, srcv1); |
|
|
|
squ_m = MS_ADDQ_F32(squ_m, squarev), squ_m1 = MS_ADDQ_F32(squ_m1, squarev1); |
|
|
|
} |
|
|
|
|
|
|
|
MS_DIVQ_F32_VEC(mean, mean1, squ_m, squ_m1, hw_planev); |
|
|
|
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); |
|
|
|
outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); |
|
|
|
outv = MS_ADDQ_F32(outv, betav), outv1 = MS_ADDQ_F32(outv1, betav1); |
|
|
|
MS_STQ_F32(dst + index * channel, outv); |
|
|
|
MS_STQ_F32(dst + index * channel + C4NUM, outv1); |
|
|
|
} |
|
|
|
} |
|
|
|
for (; c <= channel_end - C4NUM; c += C4NUM) { |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), squ_m = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), squarev = MS_MULQ_F32(srcv, srcv); |
|
|
|
mean = MS_ADDQ_F32(mean, srcv), squ_m = MS_ADDQ_F32(squ_m, squarev); |
|
|
|
} |
|
|
|
mean = MS_DIVQ_F32(mean, hw_planev), squ_m = MS_DIVQ_F32(squ_m, hw_planev); |
|
|
|
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno), betav = MS_LDQ_F32(beta_data + c); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), outv = MS_SUBQ_F32(srcv, mean); |
|
|
|
MS_STQ_F32(dst + index * channel, MS_ADDQ_F32(MS_MULQ_F32(outv, gammav), betav)); |
|
|
|
} |
|
|
|
} |
|
|
|
*c_src = c; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, |
|
|
|
const InstanceNormParameter *param, size_t task_id) { |
|
|
|
NNACL_CHECK_NULL_RETURN_ERR(src_data); |
|
|
|
@@ -130,161 +236,7 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm |
|
|
|
float *dst_b = dst_data + b * channel * hw_plane; |
|
|
|
int c = channel_begin; |
|
|
|
#if defined(ENABLE_ARM) || defined(ENABLE_SSE) |
|
|
|
for (; c <= channel_end - C16NUM; c += C16NUM) { |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
const float *src1 = src_b + (c + C4NUM) * hw_plane; |
|
|
|
const float *src2 = src_b + (c + C8NUM) * hw_plane; |
|
|
|
const float *src3 = src_b + (c + C12NUM) * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 mean3 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean2 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean3 = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv); |
|
|
|
MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1); |
|
|
|
MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2); |
|
|
|
MS_FLOAT32X4 squarev3 = MS_MULQ_F32(srcv3, srcv3); |
|
|
|
mean = MS_ADDQ_F32(mean, srcv); |
|
|
|
mean1 = MS_ADDQ_F32(mean1, srcv1); |
|
|
|
mean2 = MS_ADDQ_F32(mean2, srcv2); |
|
|
|
mean3 = MS_ADDQ_F32(mean3, srcv3); |
|
|
|
square_mean = MS_ADDQ_F32(square_mean, squarev); |
|
|
|
square_mean1 = MS_ADDQ_F32(square_mean1, squarev1); |
|
|
|
square_mean2 = MS_ADDQ_F32(square_mean2, squarev2); |
|
|
|
square_mean3 = MS_ADDQ_F32(square_mean3, squarev3); |
|
|
|
} |
|
|
|
mean = MS_DIVQ_F32(mean, hw_planev); |
|
|
|
mean1 = MS_DIVQ_F32(mean1, hw_planev); |
|
|
|
mean2 = MS_DIVQ_F32(mean2, hw_planev); |
|
|
|
mean3 = MS_DIVQ_F32(mean3, hw_planev); |
|
|
|
square_mean = MS_DIVQ_F32(square_mean, hw_planev); |
|
|
|
square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev); |
|
|
|
square_mean2 = MS_DIVQ_F32(square_mean2, hw_planev); |
|
|
|
square_mean3 = MS_DIVQ_F32(square_mean3, hw_planev); |
|
|
|
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno1 = |
|
|
|
MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno2 = |
|
|
|
MS_ADDQ_F32(MS_SUBQ_F32(square_mean2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno3 = |
|
|
|
MS_ADDQ_F32(MS_SUBQ_F32(square_mean3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); |
|
|
|
deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2)); |
|
|
|
deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c); |
|
|
|
MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM); |
|
|
|
MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM); |
|
|
|
MS_FLOAT32X4 betav3 = MS_LDQ_F32(beta_data + c + C12NUM); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean); |
|
|
|
MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1); |
|
|
|
MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2); |
|
|
|
MS_FLOAT32X4 outv3 = MS_SUBQ_F32(srcv3, mean3); |
|
|
|
outv = MS_MULQ_F32(outv, gammav); |
|
|
|
outv1 = MS_MULQ_F32(outv1, gammav1); |
|
|
|
outv2 = MS_MULQ_F32(outv2, gammav2); |
|
|
|
outv3 = MS_MULQ_F32(outv3, gammav3); |
|
|
|
outv = MS_ADDQ_F32(outv, betav); |
|
|
|
outv1 = MS_ADDQ_F32(outv1, betav1); |
|
|
|
outv2 = MS_ADDQ_F32(outv2, betav2); |
|
|
|
outv3 = MS_ADDQ_F32(outv3, betav3); |
|
|
|
MS_STQ_F32(dst + index * channel, outv); |
|
|
|
MS_STQ_F32(dst + index * channel + C4NUM, outv1); |
|
|
|
MS_STQ_F32(dst + index * channel + C8NUM, outv2); |
|
|
|
MS_STQ_F32(dst + index * channel + C12NUM, outv3); |
|
|
|
} |
|
|
|
} |
|
|
|
for (; c <= channel_end - C8NUM; c += C8NUM) { |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
const float *src1 = src_b + (c + C4NUM) * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv); |
|
|
|
MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1); |
|
|
|
mean = MS_ADDQ_F32(mean, srcv); |
|
|
|
mean1 = MS_ADDQ_F32(mean1, srcv1); |
|
|
|
square_mean = MS_ADDQ_F32(square_mean, squarev); |
|
|
|
square_mean1 = MS_ADDQ_F32(square_mean1, squarev1); |
|
|
|
} |
|
|
|
mean = MS_DIVQ_F32(mean, hw_planev); |
|
|
|
mean1 = MS_DIVQ_F32(mean1, hw_planev); |
|
|
|
square_mean = MS_DIVQ_F32(square_mean, hw_planev); |
|
|
|
square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev); |
|
|
|
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X4 deno1 = |
|
|
|
MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); |
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c); |
|
|
|
MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM); |
|
|
|
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean); |
|
|
|
MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1); |
|
|
|
outv = MS_MULQ_F32(outv, gammav); |
|
|
|
outv1 = MS_MULQ_F32(outv1, gammav1); |
|
|
|
outv = MS_ADDQ_F32(outv, betav); |
|
|
|
outv1 = MS_ADDQ_F32(outv1, betav1); |
|
|
|
MS_STQ_F32(dst + index * channel, outv); |
|
|
|
MS_STQ_F32(dst + index * channel + C4NUM, outv1); |
|
|
|
} |
|
|
|
} |
|
|
|
for (; c <= channel_end - C4NUM; c += C4NUM) { |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f); |
|
|
|
MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv); |
|
|
|
mean = MS_ADDQ_F32(mean, srcv); |
|
|
|
square_mean = MS_ADDQ_F32(square_mean, squarev); |
|
|
|
} |
|
|
|
mean = MS_DIVQ_F32(mean, hw_planev); |
|
|
|
square_mean = MS_DIVQ_F32(square_mean, hw_planev); |
|
|
|
MS_FLOAT32X4 deno = |
|
|
|
MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); // question |
|
|
|
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); |
|
|
|
|
|
|
|
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM); |
|
|
|
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean); |
|
|
|
outv = MS_MULQ_F32(outv, gammav); |
|
|
|
outv = MS_ADDQ_F32(outv, betav); |
|
|
|
MS_STQ_F32(dst + index * channel, outv); |
|
|
|
} |
|
|
|
} |
|
|
|
InstanceNormC4HW4ArmSse(src_b, dst_b, gamma_data, beta_data, &c, param, channel, channel_end, hw_plane, hw_planev); |
|
|
|
#endif |
|
|
|
for (; c < channel_end; ++c) { |
|
|
|
int c4_down_loop = c / C4NUM * C4NUM; |
|
|
|
@@ -293,15 +245,15 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm |
|
|
|
const float *src = src_b + c4_down_loop * hw_plane + c4_mod; |
|
|
|
float *dst = dst_b + c; |
|
|
|
float mean = 0.0f; |
|
|
|
float square_mean = 0.0f; |
|
|
|
float squ_m = 0.0f; |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
float tmp = src[index * c_res]; |
|
|
|
mean += tmp; |
|
|
|
square_mean += tmp * tmp; |
|
|
|
squ_m += tmp * tmp; |
|
|
|
} |
|
|
|
mean /= (float)hw_plane; |
|
|
|
square_mean /= (float)hw_plane; |
|
|
|
const float deno = gamma_data[c] / sqrtf(square_mean - mean * mean + param->epsilon_); |
|
|
|
squ_m /= (float)hw_plane; |
|
|
|
const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; |
|
|
|
} |
|
|
|
@@ -316,8 +268,7 @@ int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamm |
|
|
|
NNACL_CHECK_NULL_RETURN_ERR(src_data); |
|
|
|
NNACL_CHECK_NULL_RETURN_ERR(dst_data); |
|
|
|
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); |
|
|
|
int channel = param->channel_; |
|
|
|
int hw_plane = param->inner_size_; |
|
|
|
int channel = param->channel_, hw_plane = param->inner_size_; |
|
|
|
int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; |
|
|
|
int channel_begin = (int)(task_id)*channel_step; |
|
|
|
int channel_end = MSMIN(channel_begin + channel_step, channel); |
|
|
|
@@ -330,40 +281,33 @@ int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamm |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
const float *src1 = src_b + (c + C8NUM) * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 mean1 = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 square_mean = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 square_mean1 = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), mean1 = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 squ_m = MS_MOV256_F32(0.0f), squ_m1 = MS_MOV256_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); |
|
|
|
MS_FLOAT32X8 srcv1 = MS_LD256_F32(src1 + index * C8NUM); |
|
|
|
MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv); |
|
|
|
MS_FLOAT32X8 squarev1 = MS_MUL256_F32(srcv1, srcv1); |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); |
|
|
|
MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv), squarev1 = MS_MUL256_F32(srcv1, srcv1); |
|
|
|
mean = MS_ADD256_F32(mean, srcv); |
|
|
|
mean1 = MS_ADD256_F32(mean1, srcv1); |
|
|
|
square_mean = MS_ADD256_F32(square_mean, squarev); |
|
|
|
square_mean1 = MS_ADD256_F32(square_mean1, squarev1); |
|
|
|
squ_m = MS_ADD256_F32(squ_m, squarev); |
|
|
|
squ_m1 = MS_ADD256_F32(squ_m1, squarev1); |
|
|
|
} |
|
|
|
mean = MS_DIV256_F32(mean, hw_planev); |
|
|
|
mean1 = MS_DIV256_F32(mean1, hw_planev); |
|
|
|
square_mean = MS_DIV256_F32(square_mean, hw_planev); |
|
|
|
square_mean1 = MS_DIV256_F32(square_mean1, hw_planev); |
|
|
|
squ_m = MS_DIV256_F32(squ_m, hw_planev); |
|
|
|
squ_m1 = MS_DIV256_F32(squ_m1, hw_planev); |
|
|
|
MS_FLOAT32X8 deno = |
|
|
|
MS_ADD256_F32(MS_SUB256_F32(square_mean, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_)); |
|
|
|
MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_)); |
|
|
|
MS_FLOAT32X8 deno1 = |
|
|
|
MS_ADD256_F32(MS_SUB256_F32(square_mean1, MS_MUL256_F32(mean1, mean1)), MS_MOV256_F32(param->epsilon_)); |
|
|
|
MS_ADD256_F32(MS_SUB256_F32(squ_m1, MS_MUL256_F32(mean1, mean1)), MS_MOV256_F32(param->epsilon_)); |
|
|
|
deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); |
|
|
|
deno1 = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno1)); |
|
|
|
|
|
|
|
MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X8 gammav1 = MS_MUL256_F32(MS_LD256_F32(gamma_data + c + C8NUM), deno1); // deno1 * gamma_data[c] |
|
|
|
MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c); |
|
|
|
MS_FLOAT32X8 betav1 = MS_LD256_F32(beta_data + c + C8NUM); |
|
|
|
MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c), betav1 = MS_LD256_F32(beta_data + c + C8NUM); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); |
|
|
|
MS_FLOAT32X8 srcv1 = MS_LD256_F32(src1 + index * C8NUM); |
|
|
|
MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean); |
|
|
|
MS_FLOAT32X8 outv1 = MS_SUB256_F32(srcv1, mean1); |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); |
|
|
|
MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean), outv1 = MS_SUB256_F32(srcv1, mean1); |
|
|
|
outv = MS_MUL256_F32(outv, gammav); |
|
|
|
outv1 = MS_MUL256_F32(outv1, gammav1); |
|
|
|
outv = MS_ADD256_F32(outv, betav); |
|
|
|
@@ -375,46 +319,42 @@ int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamm |
|
|
|
for (; c <= channel_end - C8NUM; c += C8NUM) { |
|
|
|
const float *src = src_b + c * hw_plane; |
|
|
|
float *dst = dst_b + c; |
|
|
|
MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 square_mean = MS_MOV256_F32(0.0f); |
|
|
|
MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), squ_m = MS_MOV256_F32(0.0f); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); |
|
|
|
MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv); |
|
|
|
mean = MS_ADD256_F32(mean, srcv); |
|
|
|
square_mean = MS_ADD256_F32(square_mean, squarev); |
|
|
|
squ_m = MS_ADD256_F32(squ_m, squarev); |
|
|
|
} |
|
|
|
mean = MS_DIV256_F32(mean, hw_planev); |
|
|
|
square_mean = MS_DIV256_F32(square_mean, hw_planev); |
|
|
|
MS_FLOAT32X8 deno = MS_ADD256_F32(MS_SUB256_F32(square_mean, MS_MUL256_F32(mean, mean)), |
|
|
|
squ_m = MS_DIV256_F32(squ_m, hw_planev); |
|
|
|
MS_FLOAT32X8 deno = MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), |
|
|
|
MS_MOV256_F32(param->epsilon_)); // 256uestion |
|
|
|
deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); |
|
|
|
|
|
|
|
MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] |
|
|
|
MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); |
|
|
|
MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean); |
|
|
|
MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), outv = MS_SUB256_F32(srcv, mean); |
|
|
|
outv = MS_MUL256_F32(outv, gammav); |
|
|
|
outv = MS_ADD256_F32(outv, betav); |
|
|
|
MS_ST256_F32(dst + index * channel, outv); |
|
|
|
} |
|
|
|
} |
|
|
|
for (; c < channel_end; ++c) { |
|
|
|
int c8_down_loop = c / C8NUM * C8NUM; |
|
|
|
int c8_mod = c % C8NUM; |
|
|
|
int c8_down_loop = c / C8NUM * C8NUM, c8_mod = c % C8NUM; |
|
|
|
int c_res = MSMIN(channel_end - c8_down_loop, C8NUM); |
|
|
|
const float *src = src_b + c8_down_loop * hw_plane + c8_mod; |
|
|
|
float *dst = dst_b + c; |
|
|
|
float mean = 0.0f; |
|
|
|
float square_mean = 0.0f; |
|
|
|
float mean = 0.0f, squ_m = 0.0f; |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
float tmp = src[index * c_res]; |
|
|
|
mean += tmp; |
|
|
|
square_mean += tmp * tmp; |
|
|
|
squ_m += tmp * tmp; |
|
|
|
} |
|
|
|
mean /= (float)hw_plane; |
|
|
|
square_mean /= (float)hw_plane; |
|
|
|
const float deno = gamma_data[c] / sqrtf(square_mean - mean * mean + param->epsilon_); |
|
|
|
squ_m /= (float)hw_plane; |
|
|
|
const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); |
|
|
|
for (int index = 0; index < hw_plane; ++index) { |
|
|
|
dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; |
|
|
|
} |
|
|
|
|