|
|
|
@@ -18,30 +18,87 @@ |
|
|
|
#include "nnacl/errorcode.h" |
|
|
|
#include "nnacl/op_base.h" |
|
|
|
|
|
|
|
int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data, |
|
|
|
bool affine, float epsilon, float *dst_data, int tid, int thread_num) { |
|
|
|
int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data, |
|
|
|
const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num) { |
|
|
|
if (src_data == NULL || dst_data == NULL) { |
|
|
|
return NNACL_NULL_PTR; |
|
|
|
} |
|
|
|
if (affine && (gamma_data == NULL || beta_data == NULL)) { |
|
|
|
return NNACL_NULL_PTR; |
|
|
|
} |
|
|
|
for (int j = tid; j < outer_size; j += thread_num) { |
|
|
|
|
|
|
|
for (size_t j = task_id; j < outer_size; j += thread_num) { |
|
|
|
const float *src = src_data + j * inner_size; |
|
|
|
float *dst = dst_data + j * inner_size; |
|
|
|
float mean = 0.0f; |
|
|
|
float square_mean = 0.0f; |
|
|
|
for (int i = 0; i < inner_size; i++) { |
|
|
|
mean += src[i]; |
|
|
|
square_mean += src[i] * src[i]; |
|
|
|
|
|
|
|
int index = 0; |
|
|
|
#ifdef ENABLE_NEON |
|
|
|
float32x4_t sum = vdupq_n_f32(0); |
|
|
|
float32x4_t square_sum = vdupq_n_f32(0); |
|
|
|
for (; index < inner_size - C8NUM; index += C8NUM) { |
|
|
|
float32x4_t srcv1 = vld1q_f32(src + index); |
|
|
|
float32x4_t srcv2 = vld1q_f32(src + index + 4); |
|
|
|
float32x4_t squarev1 = vmulq_f32(srcv1, srcv1); |
|
|
|
float32x4_t squarev2 = vmulq_f32(srcv2, srcv2); |
|
|
|
sum = vaddq_f32(sum, srcv1); |
|
|
|
sum = vaddq_f32(sum, srcv2); |
|
|
|
square_sum = vaddq_f32(square_sum, squarev1); |
|
|
|
square_sum = vaddq_f32(square_sum, squarev2); |
|
|
|
} |
|
|
|
mean = sum[0] + sum[1] + sum[2] + sum[3]; |
|
|
|
square_mean = square_sum[0] + square_sum[1] + square_sum[2] + square_sum[3]; |
|
|
|
#endif |
|
|
|
for (; index < inner_size; index++) { |
|
|
|
mean += src[index]; |
|
|
|
square_mean += src[index] * src[index]; |
|
|
|
} |
|
|
|
|
|
|
|
mean /= (float)inner_size; |
|
|
|
square_mean /= (float)inner_size; |
|
|
|
const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon); |
|
|
|
for (int i = 0; i < inner_size; ++i) { |
|
|
|
dst[i] = (src[i] - mean) * deno; |
|
|
|
|
|
|
|
index = 0; |
|
|
|
#ifdef ENABLE_NEON |
|
|
|
float32x4_t meanv = vdupq_n_f32(mean); |
|
|
|
float32x4_t denov = vdupq_n_f32(deno); |
|
|
|
if (affine) { |
|
|
|
for (; index < inner_size - C8NUM; index += C8NUM) { |
|
|
|
float32x4_t srcv1 = vld1q_f32(src + index); |
|
|
|
float32x4_t srcv2 = vld1q_f32(src + index + 4); |
|
|
|
float32x4_t outv1 = vsubq_f32(srcv1, meanv); |
|
|
|
float32x4_t outv2 = vsubq_f32(srcv2, meanv); |
|
|
|
outv1 = vmulq_f32(outv1, denov); |
|
|
|
outv2 = vmulq_f32(outv2, denov); |
|
|
|
float32x4_t gammav1 = vld1q_f32(gamma_data + index); |
|
|
|
float32x4_t gammav2 = vld1q_f32(gamma_data + index + 4); |
|
|
|
float32x4_t betav1 = vld1q_f32(beta_data + index); |
|
|
|
float32x4_t betav2 = vld1q_f32(beta_data + index + 4); |
|
|
|
outv1 = vmulq_f32(outv1, gammav1); |
|
|
|
outv2 = vmulq_f32(outv2, gammav2); |
|
|
|
outv1 = vaddq_f32(outv1, betav1); |
|
|
|
outv2 = vaddq_f32(outv2, betav2); |
|
|
|
vst1q_f32(dst + index, outv1); |
|
|
|
vst1q_f32(dst + index + 4, outv2); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (; index < inner_size - C8NUM; index += C8NUM) { |
|
|
|
float32x4_t srcv1 = vld1q_f32(src + index); |
|
|
|
float32x4_t srcv2 = vld1q_f32(src + index + 4); |
|
|
|
float32x4_t outv1 = vsubq_f32(srcv1, meanv); |
|
|
|
float32x4_t outv2 = vsubq_f32(srcv2, meanv); |
|
|
|
outv1 = vmulq_f32(outv1, denov); |
|
|
|
outv2 = vmulq_f32(outv2, denov); |
|
|
|
vst1q_f32(dst + index, outv1); |
|
|
|
vst1q_f32(dst + index + 4, outv2); |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; index < inner_size; index++) { |
|
|
|
dst[index] = (src[index] - mean) * deno; |
|
|
|
if (affine) { |
|
|
|
dst[i] = dst[i] * gamma_data[i] + beta_data[i]; |
|
|
|
dst[index] = dst[index] * gamma_data[index] + beta_data[index]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|