| @@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) { | |||
| x += 4; | |||
| } | |||
| if (rest_n & 3) { | |||
| x[0] *= beta; | |||
| if ((rest_n & 3) > 1) | |||
| x[1] *= beta; | |||
| if ((rest_n & 3) > 2) | |||
| x[2] *= beta; | |||
| for (BLASLONG i = 0; i < (rest_n & 3); i ++) { | |||
| x[i] *= beta; | |||
| } | |||
| } | |||
| return; | |||
| @@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7; | |||
| bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7; | |||
| bfloat16x8_t x_vec; | |||
| bfloat16x4_t x_vecx4; | |||
| float32x4_t y1_vec, y2_vec; | |||
| float32x4_t fp32_low, fp32_high; | |||
| @@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| if (incx == 1 && incy == 1) { | |||
| if (beta != 1) { | |||
| beta_op(y, n, beta); | |||
| beta_op(y, m, beta); | |||
| } | |||
| for (i = 0; i < n / 8; i++) { | |||
| @@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| a_ptr += 4 * lda; | |||
| bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr); | |||
| x_vecx4 = vld1_bf16(x_ptr); | |||
| if (alpha != 1) { | |||
| x_vec = vcombine_bf16(x_vecx4, bf16_zero); | |||
| fp32_low = vreinterpretq_f32_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), | |||
| vreinterpretq_u16_bf16(x_vec))); | |||
| fp32_low = vcvt_f32_bf16(x_vecx4); | |||
| fp32_low = vmulq_n_f32(fp32_low, alpha); | |||
| x_vecx4 = vcvt_bf16_f32(fp32_low); | |||
| } | |||
| @@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| y1_vec = vld1q_f32(y_ptr); | |||
| a0 = vcombine_bf16(a0x4, bf16_zero); | |||
| a1 = vcombine_bf16(a1x4, bf16_zero); | |||
| a2 = vcombine_bf16(a2x4, bf16_zero); | |||
| a3 = vcombine_bf16(a3x4, bf16_zero); | |||
| a0 = vcombine_bf16(a0x4, a2x4); | |||
| a1 = vcombine_bf16(a1x4, a3x4); | |||
| t0 = vreinterpretq_bf16_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| t1 = vreinterpretq_bf16_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); | |||
| t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); | |||
| y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); | |||
| @@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| } | |||
| if (rest_m) { | |||
| x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); | |||
| x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); | |||
| x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); | |||
| x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); | |||
| fp32_low = vcvt_f32_bf16(x_vecx4); | |||
| x0 = vgetq_lane_f32(fp32_low, 0); | |||
| x1 = vgetq_lane_f32(fp32_low, 1); | |||
| x2 = vgetq_lane_f32(fp32_low, 2); | |||
| x3 = vgetq_lane_f32(fp32_low, 3); | |||
| for (BLASLONG j = 0; j < rest_m; j++) { | |||
| y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); | |||
| @@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| a_ptr += 2 * lda; | |||
| bfloat16_t tmp_buffer[4]; | |||
| memset((void*)tmp_buffer, 0, sizeof(bfloat16_t)); | |||
| tmp_buffer[0] = x_ptr[0]; | |||
| tmp_buffer[1] = x_ptr[1]; | |||
| x_vecx4 = vreinterpret_bf16_u16(vzip1_u16( | |||
| vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])), | |||
| vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1])) | |||
| )); | |||
| bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer); | |||
| if (alpha != 1) { | |||
| x_vec = vcombine_bf16(x_vecx4, bf16_zero); | |||
| fp32_low = vreinterpretq_f32_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), | |||
| vreinterpretq_u16_bf16(x_vec))); | |||
| fp32_low = vcvt_f32_bf16(x_vecx4); | |||
| fp32_low = vmulq_n_f32(fp32_low, alpha); | |||
| x_vecx4 = vcvt_bf16_f32(fp32_low); | |||
| } | |||
| @@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| t0 = vreinterpretq_bf16_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| t4 = vreinterpretq_bf16_u16( | |||
| t1 = vreinterpretq_bf16_u16( | |||
| vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); | |||
| y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); | |||
| y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0); | |||
| y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1); | |||
| y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0); | |||
| y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1); | |||
| vst1q_f32(y_ptr, y1_vec); | |||
| vst1q_f32(y_ptr + 4, y2_vec); | |||
| @@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| a0 = vcombine_bf16(a0x4, bf16_zero); | |||
| a1 = vcombine_bf16(a1x4, bf16_zero); | |||
| t0 = vreinterpretq_bf16_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| t1 = vreinterpretq_bf16_u16( | |||
| vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); | |||
| t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); | |||
| y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); | |||
| y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); | |||
| y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2); | |||
| y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3); | |||
| vst1q_f32(y_ptr, y1_vec); | |||
| a_ptr0 += 4; | |||
| a_ptr1 += 4; | |||
| a_ptr2 += 4; | |||
| a_ptr3 += 4; | |||
| y_ptr += 4; | |||
| } | |||
| if (m & 2) { | |||
| x0 = alpha * (vcvtah_f32_bf16(x_ptr[0])); | |||
| x1 = alpha * (vcvtah_f32_bf16(x_ptr[1])); | |||
| fp32_low = vcvt_f32_bf16(x_vecx4); | |||
| x0 = vgetq_lane_f32(fp32_low, 0); | |||
| x1 = vgetq_lane_f32(fp32_low, 1); | |||
| y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); | |||
| y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); | |||
| @@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, | |||
| } | |||
| if (m & 1) { | |||
| x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); | |||
| x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); | |||
| fp32_low = vcvt_f32_bf16(x_vecx4); | |||
| x0 = vgetq_lane_f32(fp32_low, 0); | |||
| x1 = vgetq_lane_f32(fp32_low, 1); | |||
| y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); | |||
| y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); | |||