diff --git a/mindspore/lite/nnacl/fp32/layer_norm_fp32.c b/mindspore/lite/nnacl/fp32/layer_norm_fp32.c
index a7de910a4b..a75d5dff76 100644
--- a/mindspore/lite/nnacl/fp32/layer_norm_fp32.c
+++ b/mindspore/lite/nnacl/fp32/layer_norm_fp32.c
@@ -18,30 +18,97 @@
 #include "nnacl/errorcode.h"
 #include "nnacl/op_base.h"
 
-int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data,
-              bool affine, float epsilon, float *dst_data, int tid, int thread_num) {
+// Layer normalization over outer_size rows of inner_size floats each.
+// This call handles rows task_id, task_id + thread_num, task_id + 2 * thread_num, ...
+// Each row becomes (x - mean) * 1 / sqrt(E[x^2] - mean^2 + epsilon); when affine is set, element i is then
+// multiplied by gamma_data[i] and offset by beta_data[i].
+// Returns NNACL_NULL_PTR when src_data or dst_data is NULL, or when affine is set and gamma_data or beta_data is NULL.
+int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
+              const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num) {
   if (src_data == NULL || dst_data == NULL) {
     return NNACL_NULL_PTR;
   }
   if (affine && (gamma_data == NULL || beta_data == NULL)) {
     return NNACL_NULL_PTR;
   }
-  for (int j = tid; j < outer_size; j += thread_num) {
+
+  for (size_t j = task_id; j < outer_size; j += thread_num) {
     const float *src = src_data + j * inner_size;
     float *dst = dst_data + j * inner_size;
     float mean = 0.0f;
     float square_mean = 0.0f;
-    for (int i = 0; i < inner_size; i++) {
-      mean += src[i];
-      square_mean += src[i] * src[i];
+
+    size_t index = 0;
+#ifdef ENABLE_NEON
+    float32x4_t sum = vdupq_n_f32(0);
+    float32x4_t square_sum = vdupq_n_f32(0);
+    // Accumulate eight floats per iteration. The bound is written additively because inner_size is unsigned:
+    // inner_size - C8NUM would wrap for rows shorter than C8NUM and send the loop out of bounds.
+    for (; index + C8NUM <= inner_size; index += C8NUM) {
+      float32x4_t srcv1 = vld1q_f32(src + index);
+      float32x4_t srcv2 = vld1q_f32(src + index + 4);
+      float32x4_t squarev1 = vmulq_f32(srcv1, srcv1);
+      float32x4_t squarev2 = vmulq_f32(srcv2, srcv2);
+      sum = vaddq_f32(sum, srcv1);
+      sum = vaddq_f32(sum, srcv2);
+      square_sum = vaddq_f32(square_sum, squarev1);
+      square_sum = vaddq_f32(square_sum, squarev2);
     }
+    mean = sum[0] + sum[1] + sum[2] + sum[3];
+    square_mean = square_sum[0] + square_sum[1] + square_sum[2] + square_sum[3];
+#endif
+    // Scalar tail; covers the whole row when ENABLE_NEON is not defined.
+    for (; index < inner_size; index++) {
+      mean += src[index];
+      square_mean += src[index] * src[index];
+    }
+
     mean /= (float)inner_size;
     square_mean /= (float)inner_size;
     const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon);
-    for (int i = 0; i < inner_size; ++i) {
-      dst[i] = (src[i] - mean) * deno;
+
+    index = 0;
+#ifdef ENABLE_NEON
+    float32x4_t meanv = vdupq_n_f32(mean);
+    float32x4_t denov = vdupq_n_f32(deno);
+    // Same eight-wide blocking and wrap-free bound as the accumulation loop; the affine variant adds gamma and beta.
+    if (affine) {
+      for (; index + C8NUM <= inner_size; index += C8NUM) {
+        float32x4_t srcv1 = vld1q_f32(src + index);
+        float32x4_t srcv2 = vld1q_f32(src + index + 4);
+        float32x4_t outv1 = vsubq_f32(srcv1, meanv);
+        float32x4_t outv2 = vsubq_f32(srcv2, meanv);
+        outv1 = vmulq_f32(outv1, denov);
+        outv2 = vmulq_f32(outv2, denov);
+        float32x4_t gammav1 = vld1q_f32(gamma_data + index);
+        float32x4_t gammav2 = vld1q_f32(gamma_data + index + 4);
+        float32x4_t betav1 = vld1q_f32(beta_data + index);
+        float32x4_t betav2 = vld1q_f32(beta_data + index + 4);
+        outv1 = vmulq_f32(outv1, gammav1);
+        outv2 = vmulq_f32(outv2, gammav2);
+        outv1 = vaddq_f32(outv1, betav1);
+        outv2 = vaddq_f32(outv2, betav2);
+        vst1q_f32(dst + index, outv1);
+        vst1q_f32(dst + index + 4, outv2);
+      }
+    } else {
+      for (; index + C8NUM <= inner_size; index += C8NUM) {
+        float32x4_t srcv1 = vld1q_f32(src + index);
+        float32x4_t srcv2 = vld1q_f32(src + index + 4);
+        float32x4_t outv1 = vsubq_f32(srcv1, meanv);
+        float32x4_t outv2 = vsubq_f32(srcv2, meanv);
+        outv1 = vmulq_f32(outv1, denov);
+        outv2 = vmulq_f32(outv2, denov);
+        vst1q_f32(dst + index, outv1);
+        vst1q_f32(dst + index + 4, outv2);
+      }
+    }
+#endif
+    // Scalar tail for the elements the vector loops did not cover.
+    for (; index < inner_size; index++) {
+      dst[index] = (src[index] - mean) * deno;
       if (affine) {
-        dst[i] = dst[i] * gamma_data[i] + beta_data[i];
+        dst[index] = dst[index] * gamma_data[index] + beta_data[index];
       }
     }
   }
diff --git a/mindspore/lite/nnacl/fp32/layer_norm_fp32.h b/mindspore/lite/nnacl/fp32/layer_norm_fp32.h
index a71ccabc8f..20b4a2901e 100644
--- a/mindspore/lite/nnacl/fp32/layer_norm_fp32.h
+++ b/mindspore/lite/nnacl/fp32/layer_norm_fp32.h
@@ -23,8 +23,8 @@
 extern "C" {
 #endif
 
-int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data,
-              bool affine, float epsilon, float *dst_data, int tid, int thread_num);
+int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
+              const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num);
 #ifdef __cplusplus
 }
 #endif