| @@ -294,8 +294,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| const float* weight_xc_RUN = weight_xc.row(q / 4); | |||
| const float* weight_hc_RUN = weight_hc.row(q / 4); | |||
| float32x4_t _R = vld1q_f32(bias_c_RUBNWN); | |||
| float32x4_t _U = vld1q_f32(bias_c_RUBNWN + 4); | |||
| float32x4_t _gru_R = vld1q_f32(bias_c_RUBNWN); | |||
| float32x4_t _gru_U = vld1q_f32(bias_c_RUBNWN + 4); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -316,8 +316,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_xc_R_3 = vld1q_f32(weight_xc_RUN + 24); | |||
| float32x4_t _weight_xc_U_3 = vld1q_f32(weight_xc_RUN + 28); | |||
| #if __aarch64__ | |||
| _R = vfmaq_laneq_f32(_R, _weight_xc_R, _xi, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_xc_U, _xi, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); | |||
| @@ -325,8 +325,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); | |||
| _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); | |||
| #else | |||
| _R = vmlaq_lane_f32(_R, _weight_xc_R, vget_low_f32(_xi), 0); | |||
| _U = vmlaq_lane_f32(_U, _weight_xc_U, vget_low_f32(_xi), 0); | |||
| _gru_R = vmlaq_lane_f32(_gru_R, _weight_xc_R, vget_low_f32(_xi), 0); | |||
| _gru_U = vmlaq_lane_f32(_gru_U, _weight_xc_U, vget_low_f32(_xi), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_R_1, vget_low_f32(_xi), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_U_1, vget_low_f32(_xi), 1); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_R_2, vget_high_f32(_xi), 0); | |||
| @@ -344,8 +344,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _xi = vdupq_n_f32(xi); | |||
| float32x4_t _weight_xc_R = vld1q_f32(weight_xc_RUN); | |||
| float32x4_t _weight_xc_U = vld1q_f32(weight_xc_RUN + 4); | |||
| _R = vmlaq_f32(_R, _weight_xc_R, _xi); | |||
| _U = vmlaq_f32(_U, _weight_xc_U, _xi); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); | |||
| weight_xc_RUN += 8; | |||
| } | |||
| @@ -363,8 +363,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_hc_R_3 = vld1q_f32(weight_hc_RUN + 24); | |||
| float32x4_t _weight_hc_U_3 = vld1q_f32(weight_hc_RUN + 28); | |||
| #if __aarch64__ | |||
| _R = vfmaq_laneq_f32(_R, _weight_hc_R, _h_cont, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_hc_U, _h_cont, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); | |||
| @@ -372,8 +372,8 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); | |||
| _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); | |||
| #else | |||
| _R = vmlaq_lane_f32(_R, _weight_hc_R, vget_low_f32(_h_cont), 0); | |||
| _U = vmlaq_lane_f32(_U, _weight_hc_U, vget_low_f32(_h_cont), 0); | |||
| _gru_R = vmlaq_lane_f32(_gru_R, _weight_hc_R, vget_low_f32(_h_cont), 0); | |||
| _gru_U = vmlaq_lane_f32(_gru_U, _weight_hc_U, vget_low_f32(_h_cont), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_R_1, vget_low_f32(_h_cont), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_U_1, vget_low_f32(_h_cont), 1); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_R_2, vget_high_f32(_h_cont), 0); | |||
| @@ -391,26 +391,26 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_R = vld1q_f32(weight_hc_RUN); | |||
| float32x4_t _weight_hc_U = vld1q_f32(weight_hc_RUN + 4); | |||
| _R = vmlaq_f32(_R, _weight_hc_R, _h_cont); | |||
| _U = vmlaq_f32(_U, _weight_hc_U, _h_cont); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); | |||
| weight_hc_RUN += 8; | |||
| } | |||
| _R = vaddq_f32(_R, _sum1); | |||
| _U = vaddq_f32(_U, _sum2); | |||
| _gru_R = vaddq_f32(_gru_R, _sum1); | |||
| _gru_U = vaddq_f32(_gru_U, _sum2); | |||
| _sum3 = vaddq_f32(_sum3, _sum5); | |||
| _sum4 = vaddq_f32(_sum4, _sum6); | |||
| _R = vaddq_f32(_R, _sum3); | |||
| _U = vaddq_f32(_U, _sum4); | |||
| _gru_R = vaddq_f32(_gru_R, _sum3); | |||
| _gru_U = vaddq_f32(_gru_U, _sum4); | |||
| // sigmoid(R) | |||
| // sigmoid(U) | |||
| _R = sigmoid_ps(_R); | |||
| _U = sigmoid_ps(_U); | |||
| _gru_R = sigmoid_ps(_gru_R); | |||
| _gru_U = sigmoid_ps(_gru_U); | |||
| // gate new | |||
| float32x4_t _N = vld1q_f32(bias_c_RUBNWN + 8); | |||
| float32x4_t _gru_N = vld1q_f32(bias_c_RUBNWN + 8); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -424,12 +424,12 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_hc_N_2 = vld1q_f32(weight_hc_RUN + 8); | |||
| float32x4_t _weight_hc_N_3 = vld1q_f32(weight_hc_RUN + 12); | |||
| #if __aarch64__ | |||
| _N = vfmaq_laneq_f32(_N, _weight_hc_N, _h_cont, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); | |||
| #else | |||
| _N = vmlaq_lane_f32(_N, _weight_hc_N, vget_low_f32(_h_cont), 0); | |||
| _gru_N = vmlaq_lane_f32(_gru_N, _weight_hc_N, vget_low_f32(_h_cont), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_N_1, vget_low_f32(_h_cont), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_N_2, vget_high_f32(_h_cont), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_N_3, vget_high_f32(_h_cont), 1); | |||
| @@ -443,16 +443,16 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_N = vld1q_f32(weight_hc_RUN); | |||
| _N = vmlaq_f32(_N, _weight_hc_N, _h_cont); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); | |||
| weight_hc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| _N = vmlaq_f32(vld1q_f32(bias_c_RUBNWN + 12), _R, _N); | |||
| _gru_N = vmlaq_f32(vld1q_f32(bias_c_RUBNWN + 12), _gru_R, _gru_N); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -466,12 +466,12 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_xc_N_2 = vld1q_f32(weight_xc_RUN + 8); | |||
| float32x4_t _weight_xc_N_3 = vld1q_f32(weight_xc_RUN + 12); | |||
| #if __aarch64__ | |||
| _N = vfmaq_laneq_f32(_N, _weight_xc_N, _xi, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); | |||
| #else | |||
| _N = vmlaq_lane_f32(_N, _weight_xc_N, vget_low_f32(_xi), 0); | |||
| _gru_N = vmlaq_lane_f32(_gru_N, _weight_xc_N, vget_low_f32(_xi), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_N_1, vget_low_f32(_xi), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_N_2, vget_high_f32(_xi), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_N_3, vget_high_f32(_xi), 1); | |||
| @@ -485,22 +485,22 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _xi = vdupq_n_f32(xi); | |||
| float32x4_t _weight_xc_N = vld1q_f32(weight_xc_RUN); | |||
| _N = vmlaq_f32(_N, _weight_xc_N, _xi); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); | |||
| weight_xc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| // tanh(N) | |||
| _N = tanh_ps(_N); | |||
| _gru_N = tanh_ps(_gru_N); | |||
| float* gates_data = gates.row(q / 4); | |||
| vst1q_f32(gates_data, _U); | |||
| vst1q_f32(gates_data + 4, _N); | |||
| vst1q_f32(gates_data, _gru_U); | |||
| vst1q_f32(gates_data + 4, _gru_N); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -599,13 +599,13 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| const float* gates_data = gates.row(q / 4); | |||
| float32x4_t _U = vld1q_f32(gates_data); | |||
| float32x4_t _N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _gru_U = vld1q_f32(gates_data); | |||
| float32x4_t _gru_N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _U), _N), vmulq_f32(_U, vld1q_f32(hidden_ptr + q))); | |||
| float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1q_f32(output_data + q, _H); | |||
| vst1q_f32(hidden_ptr + q, _gru_H); | |||
| vst1q_f32(output_data + q, _gru_H); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -836,8 +836,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const unsigned short* weight_xc_RUN = weight_xc.row<const unsigned short>(q / 4); | |||
| const unsigned short* weight_hc_RUN = weight_hc.row<const unsigned short>(q / 4); | |||
| float32x4_t _R = bfloat2float(vld1_u16(bias_c_RUBNWN)); | |||
| float32x4_t _U = bfloat2float(vld1_u16(bias_c_RUBNWN + 4)); | |||
| float32x4_t _gru_R = bfloat2float(vld1_u16(bias_c_RUBNWN)); | |||
| float32x4_t _gru_U = bfloat2float(vld1_u16(bias_c_RUBNWN + 4)); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -858,8 +858,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_R_3 = bfloat2float(vld1_u16(weight_xc_RUN + 24)); | |||
| float32x4_t _weight_xc_U_3 = bfloat2float(vld1_u16(weight_xc_RUN + 28)); | |||
| #if __aarch64__ | |||
| _R = vfmaq_laneq_f32(_R, _weight_xc_R, _xi, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_xc_U, _xi, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); | |||
| @@ -867,8 +867,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); | |||
| _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); | |||
| #else | |||
| _R = vmlaq_lane_f32(_R, _weight_xc_R, vget_low_f32(_xi), 0); | |||
| _U = vmlaq_lane_f32(_U, _weight_xc_U, vget_low_f32(_xi), 0); | |||
| _gru_R = vmlaq_lane_f32(_gru_R, _weight_xc_R, vget_low_f32(_xi), 0); | |||
| _gru_U = vmlaq_lane_f32(_gru_U, _weight_xc_U, vget_low_f32(_xi), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_R_1, vget_low_f32(_xi), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_U_1, vget_low_f32(_xi), 1); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_R_2, vget_high_f32(_xi), 0); | |||
| @@ -886,8 +886,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); | |||
| float32x4_t _weight_xc_R = bfloat2float(vld1_u16(weight_xc_RUN)); | |||
| float32x4_t _weight_xc_U = bfloat2float(vld1_u16(weight_xc_RUN + 4)); | |||
| _R = vmlaq_f32(_R, _weight_xc_R, _xi); | |||
| _U = vmlaq_f32(_U, _weight_xc_U, _xi); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); | |||
| weight_xc_RUN += 8; | |||
| } | |||
| @@ -905,8 +905,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_R_3 = bfloat2float(vld1_u16(weight_hc_RUN + 24)); | |||
| float32x4_t _weight_hc_U_3 = bfloat2float(vld1_u16(weight_hc_RUN + 28)); | |||
| #if __aarch64__ | |||
| _R = vfmaq_laneq_f32(_R, _weight_hc_R, _h_cont, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_hc_U, _h_cont, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); | |||
| @@ -914,8 +914,8 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); | |||
| _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); | |||
| #else | |||
| _R = vmlaq_lane_f32(_R, _weight_hc_R, vget_low_f32(_h_cont), 0); | |||
| _U = vmlaq_lane_f32(_U, _weight_hc_U, vget_low_f32(_h_cont), 0); | |||
| _gru_R = vmlaq_lane_f32(_gru_R, _weight_hc_R, vget_low_f32(_h_cont), 0); | |||
| _gru_U = vmlaq_lane_f32(_gru_U, _weight_hc_U, vget_low_f32(_h_cont), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_R_1, vget_low_f32(_h_cont), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_U_1, vget_low_f32(_h_cont), 1); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_R_2, vget_high_f32(_h_cont), 0); | |||
| @@ -933,26 +933,26 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_R = bfloat2float(vld1_u16(weight_hc_RUN)); | |||
| float32x4_t _weight_hc_U = bfloat2float(vld1_u16(weight_hc_RUN + 4)); | |||
| _R = vmlaq_f32(_R, _weight_hc_R, _h_cont); | |||
| _U = vmlaq_f32(_U, _weight_hc_U, _h_cont); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); | |||
| weight_hc_RUN += 8; | |||
| } | |||
| _R = vaddq_f32(_R, _sum1); | |||
| _U = vaddq_f32(_U, _sum2); | |||
| _gru_R = vaddq_f32(_gru_R, _sum1); | |||
| _gru_U = vaddq_f32(_gru_U, _sum2); | |||
| _sum3 = vaddq_f32(_sum3, _sum5); | |||
| _sum4 = vaddq_f32(_sum4, _sum6); | |||
| _R = vaddq_f32(_R, _sum3); | |||
| _U = vaddq_f32(_U, _sum4); | |||
| _gru_R = vaddq_f32(_gru_R, _sum3); | |||
| _gru_U = vaddq_f32(_gru_U, _sum4); | |||
| // sigmoid(R) | |||
| // sigmoid(U) | |||
| _R = sigmoid_ps(_R); | |||
| _U = sigmoid_ps(_U); | |||
| _gru_R = sigmoid_ps(_gru_R); | |||
| _gru_U = sigmoid_ps(_gru_U); | |||
| // gate new | |||
| float32x4_t _N = bfloat2float(vld1_u16(bias_c_RUBNWN + 8)); | |||
| float32x4_t _gru_N = bfloat2float(vld1_u16(bias_c_RUBNWN + 8)); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -966,12 +966,12 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_N_2 = bfloat2float(vld1_u16(weight_hc_RUN + 8)); | |||
| float32x4_t _weight_hc_N_3 = bfloat2float(vld1_u16(weight_hc_RUN + 12)); | |||
| #if __aarch64__ | |||
| _N = vfmaq_laneq_f32(_N, _weight_hc_N, _h_cont, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); | |||
| #else | |||
| _N = vmlaq_lane_f32(_N, _weight_hc_N, vget_low_f32(_h_cont), 0); | |||
| _gru_N = vmlaq_lane_f32(_gru_N, _weight_hc_N, vget_low_f32(_h_cont), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_N_1, vget_low_f32(_h_cont), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_N_2, vget_high_f32(_h_cont), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_N_3, vget_high_f32(_h_cont), 1); | |||
| @@ -985,16 +985,16 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_N = bfloat2float(vld1_u16(weight_hc_RUN)); | |||
| _N = vmlaq_f32(_N, _weight_hc_N, _h_cont); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); | |||
| weight_hc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| _N = vmlaq_f32(bfloat2float(vld1_u16(bias_c_RUBNWN + 12)), _R, _N); | |||
| _gru_N = vmlaq_f32(bfloat2float(vld1_u16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -1008,12 +1008,12 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_N_2 = bfloat2float(vld1_u16(weight_xc_RUN + 8)); | |||
| float32x4_t _weight_xc_N_3 = bfloat2float(vld1_u16(weight_xc_RUN + 12)); | |||
| #if __aarch64__ | |||
| _N = vfmaq_laneq_f32(_N, _weight_xc_N, _xi, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); | |||
| #else | |||
| _N = vmlaq_lane_f32(_N, _weight_xc_N, vget_low_f32(_xi), 0); | |||
| _gru_N = vmlaq_lane_f32(_gru_N, _weight_xc_N, vget_low_f32(_xi), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_N_1, vget_low_f32(_xi), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_N_2, vget_high_f32(_xi), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_N_3, vget_high_f32(_xi), 1); | |||
| @@ -1027,22 +1027,22 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _xi = bfloat2float(vdup_n_u16(xi)); | |||
| float32x4_t _weight_xc_N = bfloat2float(vld1_u16(weight_xc_RUN)); | |||
| _N = vmlaq_f32(_N, _weight_xc_N, _xi); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); | |||
| weight_xc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| // tanh(N) | |||
| _N = tanh_ps(_N); | |||
| _gru_N = tanh_ps(_gru_N); | |||
| float* gates_data = gates.row(q / 4); | |||
| vst1q_f32(gates_data, _U); | |||
| vst1q_f32(gates_data + 4, _N); | |||
| vst1q_f32(gates_data, _gru_U); | |||
| vst1q_f32(gates_data + 4, _gru_N); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -1141,13 +1141,13 @@ static int gru_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const float* gates_data = gates.row(q / 4); | |||
| float32x4_t _U = vld1q_f32(gates_data); | |||
| float32x4_t _N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _gru_U = vld1q_f32(gates_data); | |||
| float32x4_t _gru_N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _U), _N), vmulq_f32(_U, vld1q_f32(hidden_ptr + q))); | |||
| float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_u16(output_data + q, float2bfloat(_H)); | |||
| vst1q_f32(hidden_ptr + q, _gru_H); | |||
| vst1_u16(output_data + q, float2bfloat(_gru_H)); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -57,8 +57,8 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const __fp16* weight_xc_RUN = weight_xc.row<const __fp16>(q / 4); | |||
| const __fp16* weight_hc_RUN = weight_hc.row<const __fp16>(q / 4); | |||
| float32x4_t _R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); | |||
| float32x4_t _U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); | |||
| float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); | |||
| float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -78,8 +78,8 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_U_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 20)); | |||
| float32x4_t _weight_xc_R_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 24)); | |||
| float32x4_t _weight_xc_U_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 28)); | |||
| _R = vfmaq_laneq_f32(_R, _weight_xc_R, _xi, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_xc_U, _xi, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); | |||
| @@ -96,8 +96,8 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); | |||
| float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); | |||
| float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); | |||
| _R = vmlaq_f32(_R, _weight_xc_R, _xi); | |||
| _U = vmlaq_f32(_U, _weight_xc_U, _xi); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); | |||
| weight_xc_RUN += 8; | |||
| } | |||
| @@ -114,8 +114,8 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_U_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 20)); | |||
| float32x4_t _weight_hc_R_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 24)); | |||
| float32x4_t _weight_hc_U_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 28)); | |||
| _R = vfmaq_laneq_f32(_R, _weight_hc_R, _h_cont, 0); | |||
| _U = vfmaq_laneq_f32(_U, _weight_hc_U, _h_cont, 0); | |||
| _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); | |||
| _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); | |||
| @@ -132,26 +132,26 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); | |||
| float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); | |||
| _R = vmlaq_f32(_R, _weight_hc_R, _h_cont); | |||
| _U = vmlaq_f32(_U, _weight_hc_U, _h_cont); | |||
| _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); | |||
| _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); | |||
| weight_hc_RUN += 8; | |||
| } | |||
| _R = vaddq_f32(_R, _sum1); | |||
| _U = vaddq_f32(_U, _sum2); | |||
| _gru_R = vaddq_f32(_gru_R, _sum1); | |||
| _gru_U = vaddq_f32(_gru_U, _sum2); | |||
| _sum3 = vaddq_f32(_sum3, _sum5); | |||
| _sum4 = vaddq_f32(_sum4, _sum6); | |||
| _R = vaddq_f32(_R, _sum3); | |||
| _U = vaddq_f32(_U, _sum4); | |||
| _gru_R = vaddq_f32(_gru_R, _sum3); | |||
| _gru_U = vaddq_f32(_gru_U, _sum4); | |||
| // sigmoid(R) | |||
| // sigmoid(U) | |||
| _R = sigmoid_ps(_R); | |||
| _U = sigmoid_ps(_U); | |||
| _gru_R = sigmoid_ps(_gru_R); | |||
| _gru_U = sigmoid_ps(_gru_U); | |||
| // gate new | |||
| float32x4_t _N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); | |||
| float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -164,7 +164,7 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_N_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); | |||
| float32x4_t _weight_hc_N_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); | |||
| float32x4_t _weight_hc_N_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); | |||
| _N = vfmaq_laneq_f32(_N, _weight_hc_N, _h_cont, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); | |||
| @@ -177,16 +177,16 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _h_cont = vdupq_n_f32(h_cont); | |||
| float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); | |||
| _N = vmlaq_f32(_N, _weight_hc_N, _h_cont); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); | |||
| weight_hc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| _N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _R, _N); | |||
| _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); | |||
| _sum1 = vdupq_n_f32(0.f); | |||
| _sum2 = vdupq_n_f32(0.f); | |||
| _sum3 = vdupq_n_f32(0.f); | |||
| @@ -199,7 +199,7 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_N_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); | |||
| float32x4_t _weight_xc_N_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); | |||
| float32x4_t _weight_xc_N_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); | |||
| _N = vfmaq_laneq_f32(_N, _weight_xc_N, _xi, 0); | |||
| _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); | |||
| @@ -212,22 +212,22 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); | |||
| float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); | |||
| _N = vmlaq_f32(_N, _weight_xc_N, _xi); | |||
| _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); | |||
| weight_xc_RUN += 4; | |||
| } | |||
| _N = vaddq_f32(_N, _sum1); | |||
| _gru_N = vaddq_f32(_gru_N, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _N = vaddq_f32(_N, _sum2); | |||
| _gru_N = vaddq_f32(_gru_N, _sum2); | |||
| // tanh(N) | |||
| _N = tanh_ps(_N); | |||
| _gru_N = tanh_ps(_gru_N); | |||
| float* gates_data = gates.row(q / 4); | |||
| vst1q_f32(gates_data, _U); | |||
| vst1q_f32(gates_data + 4, _N); | |||
| vst1q_f32(gates_data, _gru_U); | |||
| vst1q_f32(gates_data + 4, _gru_N); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -314,13 +314,13 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const float* gates_data = gates.row(q / 4); | |||
| float32x4_t _U = vld1q_f32(gates_data); | |||
| float32x4_t _N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _gru_U = vld1q_f32(gates_data); | |||
| float32x4_t _gru_N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _U), _N), vmulq_f32(_U, vld1q_f32(hidden_ptr + q))); | |||
| float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _gru_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_gru_H)); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -463,7 +463,7 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| hidden_ptr = hidden_state; | |||
| // gate new | |||
| float16x4_t _N = vld1_f16(bias_c_RUBNWN + 8); | |||
| float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); | |||
| float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); | |||
| float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); | |||
| float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); | |||
| @@ -481,13 +481,13 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| "fmla %5.4h, v3.4h, v4.h[3] \n" | |||
| : "=r"(hidden_ptr), | |||
| "=r"(weight_hc_RUN), | |||
| "=w"(_N), | |||
| "=w"(_gru_N), | |||
| "=w"(_sum4), | |||
| "=w"(_sum5), | |||
| "=w"(_sum6) | |||
| : "0"(hidden_ptr), | |||
| "1"(weight_hc_RUN), | |||
| "2"(_N), | |||
| "2"(_gru_N), | |||
| "3"(_sum4), | |||
| "4"(_sum5), | |||
| "5"(_sum6) | |||
| @@ -499,16 +499,16 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); | |||
| float16x4_t _weight_hc_N = vld1_f16(weight_hc_RUN); | |||
| _N = vfma_f16(_N, _weight_hc_N, _h_cont); | |||
| _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); | |||
| weight_hc_RUN += 4; | |||
| } | |||
| _N = vadd_f16(_N, _sum4); | |||
| _gru_N = vadd_f16(_gru_N, _sum4); | |||
| _sum5 = vadd_f16(_sum5, _sum6); | |||
| _N = vadd_f16(_N, _sum5); | |||
| _gru_N = vadd_f16(_gru_N, _sum5); | |||
| _N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _N); | |||
| _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); | |||
| _sum4 = vdup_n_f16((__fp16)0.f); | |||
| _sum5 = vdup_n_f16((__fp16)0.f); | |||
| _sum6 = vdup_n_f16((__fp16)0.f); | |||
| @@ -525,13 +525,13 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| "fmla %5.4h, v3.4h, v4.h[3] \n" | |||
| : "=r"(x), | |||
| "=r"(weight_xc_RUN), | |||
| "=w"(_N), | |||
| "=w"(_gru_N), | |||
| "=w"(_sum4), | |||
| "=w"(_sum5), | |||
| "=w"(_sum6) | |||
| : "0"(x), | |||
| "1"(weight_xc_RUN), | |||
| "2"(_N), | |||
| "2"(_gru_N), | |||
| "3"(_sum4), | |||
| "4"(_sum5), | |||
| "5"(_sum6) | |||
| @@ -543,17 +543,17 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x4_t _xi = vdup_n_f16(xi); | |||
| float16x4_t _weight_xc_N = vld1_f16(weight_xc_RUN); | |||
| _N = vfma_f16(_N, _weight_xc_N, _xi); | |||
| _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); | |||
| weight_xc_RUN += 4; | |||
| } | |||
| _N = vadd_f16(_N, _sum4); | |||
| _gru_N = vadd_f16(_gru_N, _sum4); | |||
| _sum5 = vadd_f16(_sum5, _sum6); | |||
| _N = vadd_f16(_N, _sum5); | |||
| _gru_N = vadd_f16(_gru_N, _sum5); | |||
| // tanh(N) | |||
| float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_N)); | |||
| float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); | |||
| float* gates_data = gates.row(q / 4); | |||
| @@ -645,13 +645,13 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| const float* gates_data = gates.row(q / 4); | |||
| float32x4_t _U = vld1q_f32(gates_data); | |||
| float32x4_t _N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _gru_U = vld1q_f32(gates_data); | |||
| float32x4_t _gru_N = vld1q_f32(gates_data + 4); | |||
| float32x4_t _H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _U), _N), vmulq_f32(_U, vld1q_f32(hidden_ptr + q))); | |||
| float32x4_t _gru_H = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U), _gru_N), vmulq_f32(_gru_U, vld1q_f32(hidden_ptr + q))); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _gru_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_gru_H)); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -254,11 +254,11 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _rows2 = vld1q_f32(rows2p); | |||
| float32x4_t _rows3 = vld1q_f32(rows3p); | |||
| float32x4_t _D = vmulq_lane_f32(_rows0, vget_low_f32(_b0123), 0); | |||
| _D = vmlaq_lane_f32(_D, _rows1, vget_low_f32(_b0123), 1); | |||
| _D = vmlaq_lane_f32(_D, _rows2, vget_high_f32(_b0123), 0); | |||
| _D = vmlaq_lane_f32(_D, _rows3, vget_high_f32(_b0123), 1); | |||
| vst1q_f32(Dp, _D); | |||
| float32x4_t _Dp = vmulq_lane_f32(_rows0, vget_low_f32(_b0123), 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows1, vget_low_f32(_b0123), 1); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows2, vget_high_f32(_b0123), 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows3, vget_high_f32(_b0123), 1); | |||
| vst1q_f32(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -254,11 +254,11 @@ static void resize_bicubic_image_pack4_bf16s(const Mat& src, Mat& dst, float* al | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _rows2 = vld1q_f32(rows2p); | |||
| float32x4_t _rows3 = vld1q_f32(rows3p); | |||
| float32x4_t _D = vmulq_lane_f32(_rows0, vget_low_f32(_b0123), 0); | |||
| _D = vmlaq_lane_f32(_D, _rows1, vget_low_f32(_b0123), 1); | |||
| _D = vmlaq_lane_f32(_D, _rows2, vget_high_f32(_b0123), 0); | |||
| _D = vmlaq_lane_f32(_D, _rows3, vget_high_f32(_b0123), 1); | |||
| vst1_u16(Dp, float2bfloat(_D)); | |||
| float32x4_t _Dp = vmulq_lane_f32(_rows0, vget_low_f32(_b0123), 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows1, vget_low_f32(_b0123), 1); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows2, vget_high_f32(_b0123), 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows3, vget_high_f32(_b0123), 1); | |||
| vst1_u16(Dp, float2bfloat(_Dp)); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -253,11 +253,11 @@ static void resize_bicubic_image_pack4_fp16s(const Mat& src, Mat& dst, float* al | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _rows2 = vld1q_f32(rows2p); | |||
| float32x4_t _rows3 = vld1q_f32(rows3p); | |||
| float32x4_t _D = vmulq_laneq_f32(_rows0, _b0123, 0); | |||
| _D = vfmaq_laneq_f32(_D, _rows1, _b0123, 1); | |||
| _D = vfmaq_laneq_f32(_D, _rows2, _b0123, 2); | |||
| _D = vfmaq_laneq_f32(_D, _rows3, _b0123, 3); | |||
| vst1_f16(Dp, vcvt_f16_f32(_D)); | |||
| float32x4_t _Dp = vmulq_laneq_f32(_rows0, _b0123, 0); | |||
| _Dp = vfmaq_laneq_f32(_Dp, _rows1, _b0123, 1); | |||
| _Dp = vfmaq_laneq_f32(_Dp, _rows2, _b0123, 2); | |||
| _Dp = vfmaq_laneq_f32(_Dp, _rows3, _b0123, 3); | |||
| vst1_f16(Dp, vcvt_f16_f32(_Dp)); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -511,11 +511,11 @@ static void resize_bicubic_image_pack4_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| float16x4_t _rows1 = vld1_f16(rows1p); | |||
| float16x4_t _rows2 = vld1_f16(rows2p); | |||
| float16x4_t _rows3 = vld1_f16(rows3p); | |||
| float16x4_t _D = vmul_lane_f16(_rows0, _b0123, 0); | |||
| _D = vfma_lane_f16(_D, _rows1, _b0123, 1); | |||
| _D = vfma_lane_f16(_D, _rows2, _b0123, 2); | |||
| _D = vfma_lane_f16(_D, _rows3, _b0123, 3); | |||
| vst1_f16(Dp, _D); | |||
| float16x4_t _Dp = vmul_lane_f16(_rows0, _b0123, 0); | |||
| _Dp = vfma_lane_f16(_Dp, _rows1, _b0123, 1); | |||
| _Dp = vfma_lane_f16(_Dp, _rows2, _b0123, 2); | |||
| _Dp = vfma_lane_f16(_Dp, _rows3, _b0123, 3); | |||
| vst1_f16(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -253,11 +253,11 @@ static void resize_bicubic_image_pack8_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| float16x8_t _rows1 = vld1q_f16(rows1p); | |||
| float16x8_t _rows2 = vld1q_f16(rows2p); | |||
| float16x8_t _rows3 = vld1q_f16(rows3p); | |||
| float16x8_t _D = vmulq_lane_f16(_rows0, _b0123, 0); | |||
| _D = vfmaq_lane_f16(_D, _rows1, _b0123, 1); | |||
| _D = vfmaq_lane_f16(_D, _rows2, _b0123, 2); | |||
| _D = vfmaq_lane_f16(_D, _rows3, _b0123, 3); | |||
| vst1q_f16(Dp, _D); | |||
| float16x8_t _Dp = vmulq_lane_f16(_rows0, _b0123, 0); | |||
| _Dp = vfmaq_lane_f16(_Dp, _rows1, _b0123, 1); | |||
| _Dp = vfmaq_lane_f16(_Dp, _rows2, _b0123, 2); | |||
| _Dp = vfmaq_lane_f16(_Dp, _rows3, _b0123, 3); | |||
| vst1q_f16(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -193,18 +193,18 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_f32(_rows0, _b0); | |||
| _D = vmlaq_f32(_D, _rows1, _b1); | |||
| float32x4_t _Dp = vmulq_f32(_rows0, _b0); | |||
| _Dp = vmlaq_f32(_Dp, _rows1, _b1); | |||
| vst1q_f32(Dp, _D); | |||
| vst1q_f32(Dp, _Dp); | |||
| float32x4_t _rows0n = vld1q_f32(rows0p + 4); | |||
| float32x4_t _rows1n = vld1q_f32(rows1p + 4); | |||
| float32x4_t _Dn = vmulq_f32(_rows0n, _b0); | |||
| _Dn = vmlaq_f32(_Dn, _rows1n, _b1); | |||
| float32x4_t _Dpn = vmulq_f32(_rows0n, _b0); | |||
| _Dpn = vmlaq_f32(_Dpn, _rows1n, _b1); | |||
| vst1q_f32(Dp + 4, _Dn); | |||
| vst1q_f32(Dp + 4, _Dpn); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -106,18 +106,18 @@ static void resize_bilinear_image_bf16s(const Mat& src, Mat& dst, float* alpha, | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_f32(_rows0, _b0); | |||
| _D = vmlaq_f32(_D, _rows1, _b1); | |||
| float32x4_t _Dp = vmulq_f32(_rows0, _b0); | |||
| _Dp = vmlaq_f32(_Dp, _rows1, _b1); | |||
| vst1_u16(Dp, float2bfloat(_D)); | |||
| vst1_u16(Dp, float2bfloat(_Dp)); | |||
| float32x4_t _rows0n = vld1q_f32(rows0p + 4); | |||
| float32x4_t _rows1n = vld1q_f32(rows1p + 4); | |||
| float32x4_t _Dn = vmulq_f32(_rows0n, _b0); | |||
| _Dn = vmlaq_f32(_Dn, _rows1n, _b1); | |||
| float32x4_t _Dpn = vmulq_f32(_rows0n, _b0); | |||
| _Dpn = vmlaq_f32(_Dpn, _rows1n, _b1); | |||
| vst1_u16(Dp + 4, float2bfloat(_Dn)); | |||
| vst1_u16(Dp + 4, float2bfloat(_Dpn)); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -138,10 +138,10 @@ static void resize_bilinear_image_fp16s(const Mat& src, Mat& dst, float* alpha, | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_f32(_rows0, _b0); | |||
| _D = vfmaq_f32(_D, _rows1, _b1); | |||
| float32x4_t _Dp = vmulq_f32(_rows0, _b0); | |||
| _Dp = vfmaq_f32(_Dp, _rows1, _b1); | |||
| vst1_f16(Dp, vcvt_f16_f32(_D)); | |||
| vst1_f16(Dp, vcvt_f16_f32(_Dp)); | |||
| float32x4_t _rows0n = vld1q_f32(rows0p + 4); | |||
| float32x4_t _rows1n = vld1q_f32(rows1p + 4); | |||
| @@ -254,10 +254,10 @@ static void resize_bilinear_image_fp16sa(const Mat& src, Mat& dst, __fp16* alpha | |||
| float16x8_t _rows0 = vld1q_f16(rows0p); | |||
| float16x8_t _rows1 = vld1q_f16(rows1p); | |||
| float16x8_t _D = vmulq_f16(_rows0, _b0); | |||
| _D = vfmaq_f16(_D, _rows1, _b1); | |||
| float16x8_t _Dp = vmulq_f16(_rows0, _b0); | |||
| _Dp = vfmaq_f16(_Dp, _rows1, _b1); | |||
| vst1q_f16(Dp, _D); | |||
| vst1q_f16(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _D = vmlaq_lane_f32(_D, _rows1, _b01, 1); | |||
| vst1q_f32(Dp, _D); | |||
| float32x4_t _Dp = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows1, _b01, 1); | |||
| vst1q_f32(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_pack4_bf16s(const Mat& src, Mat& dst, float* a | |||
| { | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _D = vmlaq_lane_f32(_D, _rows1, _b01, 1); | |||
| vst1_u16(Dp, float2bfloat(_D)); | |||
| float32x4_t _Dp = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows1, _b01, 1); | |||
| vst1_u16(Dp, float2bfloat(_Dp)); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_pack4_fp16s(const Mat& src, Mat& dst, float* a | |||
| { | |||
| float32x4_t _rows0 = vld1q_f32(rows0p); | |||
| float32x4_t _rows1 = vld1q_f32(rows1p); | |||
| float32x4_t _D = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _D = vmlaq_lane_f32(_D, _rows1, _b01, 1); | |||
| vst1_f16(Dp, vcvt_f16_f32(_D)); | |||
| float32x4_t _Dp = vmulq_lane_f32(_rows0, _b01, 0); | |||
| _Dp = vmlaq_lane_f32(_Dp, _rows1, _b01, 1); | |||
| vst1_f16(Dp, vcvt_f16_f32(_Dp)); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -213,9 +213,9 @@ static void resize_bilinear_image_pack4_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| { | |||
| float16x4_t _rows0 = vld1_f16(rows0p); | |||
| float16x4_t _rows1 = vld1_f16(rows1p); | |||
| float16x4_t _D = vmul_lane_f16(_rows0, _b01, 0); | |||
| _D = vfma_lane_f16(_D, _rows1, _b01, 1); | |||
| vst1_f16(Dp, _D); | |||
| float16x4_t _Dp = vmul_lane_f16(_rows0, _b01, 0); | |||
| _Dp = vfma_lane_f16(_Dp, _rows1, _b01, 1); | |||
| vst1_f16(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_pack8_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| { | |||
| float16x8_t _rows0 = vld1q_f16(rows0p); | |||
| float16x8_t _rows1 = vld1q_f16(rows1p); | |||
| float16x8_t _D = vmulq_lane_f16(_rows0, _b01, 0); | |||
| _D = vfmaq_lane_f16(_D, _rows1, _b01, 1); | |||
| vst1q_f16(Dp, _D); | |||
| float16x8_t _Dp = vmulq_lane_f16(_rows0, _b01, 0); | |||
| _Dp = vfmaq_lane_f16(_Dp, _rows1, _b01, 1); | |||
| vst1q_f16(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -323,24 +323,24 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); | |||
| float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_I, _G)); | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); | |||
| float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1q_f32(output_data + q, _H); | |||
| vst1q_f32(hidden_ptr + q, _lstm_H); | |||
| vst1q_f32(output_data + q, _lstm_H); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| vst1q_f32(tmp_hidden_ptr + q, _lstm_H); | |||
| } | |||
| } | |||
| #endif // __ARM_NEON | |||
| @@ -778,24 +778,24 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); | |||
| float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_I, _G)); | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); | |||
| float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_u16(output_data + q, float2bfloat(_H)); | |||
| vst1q_f32(hidden_ptr + q, _lstm_H); | |||
| vst1_u16(output_data + q, float2bfloat(_lstm_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| vst1q_f32(tmp_hidden_ptr + q, _lstm_H); | |||
| } | |||
| } | |||
| #endif // __ARM_NEON | |||
| @@ -163,24 +163,24 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); | |||
| float32x4_t _I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); | |||
| float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); | |||
| float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); | |||
| float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_I, _G)); | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); | |||
| float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _lstm_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| vst1q_f32(tmp_hidden_ptr + q, _lstm_H); | |||
| } | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -503,24 +503,24 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x4x4_t _IFOG_4x4 = vld4_f16(gates_data); | |||
| float32x4_t _I = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[0])); | |||
| float32x4_t _F = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[1])); | |||
| float32x4_t _O = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[2])); | |||
| float32x4_t _G = tanh_ps(vcvt_f32_f16(_IFOG_4x4.val[3])); | |||
| float32x4_t _lstm_I = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[0])); | |||
| float32x4_t _lstm_F = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[1])); | |||
| float32x4_t _lstm_O = sigmoid_ps(vcvt_f32_f16(_IFOG_4x4.val[2])); | |||
| float32x4_t _lstm_G = tanh_ps(vcvt_f32_f16(_IFOG_4x4.val[3])); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_I, _G)); | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); | |||
| float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _lstm_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| vst1q_f32(tmp_hidden_ptr + q, _lstm_H); | |||
| } | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -176,7 +176,7 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| const float* weight_xc_ptr = weight_xc.row(q / 4); | |||
| const float* weight_hc_ptr = weight_hc.row(q / 4); | |||
| float32x4_t _H = vld1q_f32((const float*)bias_c + q); | |||
| float32x4_t _rnn_H = vld1q_f32((const float*)bias_c + q); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -190,12 +190,12 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_xc_2 = vld1q_f32(weight_xc_ptr + 8); | |||
| float32x4_t _weight_xc_3 = vld1q_f32(weight_xc_ptr + 12); | |||
| #if __aarch64__ | |||
| _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); | |||
| #else | |||
| _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0); | |||
| _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_xc, vget_low_f32(_x), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1); | |||
| @@ -207,7 +207,7 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| { | |||
| float32x4_t _x = vdupq_n_f32(x[i]); | |||
| float32x4_t _weight_xc = vld1q_f32(weight_xc_ptr); | |||
| _H = vmlaq_f32(_H, _weight_xc, _x); | |||
| _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); | |||
| weight_xc_ptr += 4; | |||
| } | |||
| @@ -221,12 +221,12 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| float32x4_t _weight_hc_2 = vld1q_f32(weight_hc_ptr + 8); | |||
| float32x4_t _weight_hc_3 = vld1q_f32(weight_hc_ptr + 12); | |||
| #if __aarch64__ | |||
| _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); | |||
| #else | |||
| _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0); | |||
| _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_hc, vget_low_f32(_hidden_state), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1); | |||
| @@ -238,18 +238,18 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| { | |||
| float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); | |||
| float32x4_t _weight_hc = vld1q_f32(weight_hc_ptr); | |||
| _H = vmlaq_f32(_H, _weight_hc, _hidden_state); | |||
| _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); | |||
| weight_hc_ptr += 4; | |||
| } | |||
| _H = vaddq_f32(_H, _sum1); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _H = vaddq_f32(_H, _sum2); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum2); | |||
| _H = tanh_ps(_H); | |||
| _rnn_H = tanh_ps(_rnn_H); | |||
| vst1q_f32((float*)gates + q, _H); | |||
| vst1q_f32((float*)gates + q, _rnn_H); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -293,10 +293,10 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we | |||
| { | |||
| int q = qq * 4; | |||
| float32x4_t _H = vld1q_f32((float*)gates + q); | |||
| float32x4_t _rnn_H = vld1q_f32((float*)gates + q); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1q_f32(output_data + q, _H); | |||
| vst1q_f32(hidden_ptr + q, _rnn_H); | |||
| vst1q_f32(output_data + q, _rnn_H); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -511,7 +511,7 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const unsigned short* weight_xc_ptr = weight_xc.row<const unsigned short>(q / 4); | |||
| const unsigned short* weight_hc_ptr = weight_hc.row<const unsigned short>(q / 4); | |||
| float32x4_t _H = bfloat2float(vld1_u16((const unsigned short*)bias_c + q)); | |||
| float32x4_t _rnn_H = bfloat2float(vld1_u16((const unsigned short*)bias_c + q)); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -525,12 +525,12 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_2 = bfloat2float(vld1_u16(weight_xc_ptr + 8)); | |||
| float32x4_t _weight_xc_3 = bfloat2float(vld1_u16(weight_xc_ptr + 12)); | |||
| #if __aarch64__ | |||
| _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); | |||
| #else | |||
| _H = vmlaq_lane_f32(_H, _weight_xc, vget_low_f32(_x), 0); | |||
| _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_xc, vget_low_f32(_x), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_xc_1, vget_low_f32(_x), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_xc_2, vget_high_f32(_x), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_xc_3, vget_high_f32(_x), 1); | |||
| @@ -542,7 +542,7 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| float32x4_t _x = bfloat2float(vdup_n_u16(x[i])); | |||
| float32x4_t _weight_xc = bfloat2float(vld1_u16(weight_xc_ptr)); | |||
| _H = vmlaq_f32(_H, _weight_xc, _x); | |||
| _rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x); | |||
| weight_xc_ptr += 4; | |||
| } | |||
| @@ -556,12 +556,12 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_2 = bfloat2float(vld1_u16(weight_hc_ptr + 8)); | |||
| float32x4_t _weight_hc_3 = bfloat2float(vld1_u16(weight_hc_ptr + 12)); | |||
| #if __aarch64__ | |||
| _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); | |||
| #else | |||
| _H = vmlaq_lane_f32(_H, _weight_hc, vget_low_f32(_hidden_state), 0); | |||
| _rnn_H = vmlaq_lane_f32(_rnn_H, _weight_hc, vget_low_f32(_hidden_state), 0); | |||
| _sum1 = vmlaq_lane_f32(_sum1, _weight_hc_1, vget_low_f32(_hidden_state), 1); | |||
| _sum2 = vmlaq_lane_f32(_sum2, _weight_hc_2, vget_high_f32(_hidden_state), 0); | |||
| _sum3 = vmlaq_lane_f32(_sum3, _weight_hc_3, vget_high_f32(_hidden_state), 1); | |||
| @@ -573,18 +573,18 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); | |||
| float32x4_t _weight_hc = bfloat2float(vld1_u16(weight_hc_ptr)); | |||
| _H = vmlaq_f32(_H, _weight_hc, _hidden_state); | |||
| _rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state); | |||
| weight_hc_ptr += 4; | |||
| } | |||
| _H = vaddq_f32(_H, _sum1); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _H = vaddq_f32(_H, _sum2); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum2); | |||
| _H = tanh_ps(_H); | |||
| _rnn_H = tanh_ps(_rnn_H); | |||
| vst1q_f32((float*)gates + q, _H); | |||
| vst1q_f32((float*)gates + q, _rnn_H); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -628,10 +628,10 @@ static int rnn_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| int q = qq * 4; | |||
| float32x4_t _H = vld1q_f32((float*)gates + q); | |||
| float32x4_t _rnn_H = vld1q_f32((float*)gates + q); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_u16(output_data + q, float2bfloat(_H)); | |||
| vst1q_f32(hidden_ptr + q, _rnn_H); | |||
| vst1_u16(output_data + q, float2bfloat(_rnn_H)); | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -54,7 +54,7 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 4); | |||
| const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 4); | |||
| float32x4_t _H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q)); | |||
| float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q)); | |||
| float32x4_t _sum1 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum2 = vdupq_n_f32(0.f); | |||
| float32x4_t _sum3 = vdupq_n_f32(0.f); | |||
| @@ -67,7 +67,7 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4)); | |||
| float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8)); | |||
| float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12)); | |||
| _H = vfmaq_laneq_f32(_H, _weight_xc, _x, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); | |||
| @@ -78,7 +78,7 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i])); | |||
| float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr)); | |||
| _H = vfmaq_f32(_H, _weight_xc, _x); | |||
| _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x); | |||
| weight_xc_ptr += 4; | |||
| } | |||
| @@ -91,7 +91,7 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4)); | |||
| float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8)); | |||
| float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12)); | |||
| _H = vfmaq_laneq_f32(_H, _weight_hc, _hidden_state, 0); | |||
| _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0); | |||
| _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); | |||
| _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); | |||
| _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); | |||
| @@ -102,18 +102,18 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); | |||
| float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr)); | |||
| _H = vfmaq_f32(_H, _weight_hc, _hidden_state); | |||
| _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state); | |||
| weight_hc_ptr += 4; | |||
| } | |||
| _H = vaddq_f32(_H, _sum1); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum1); | |||
| _sum2 = vaddq_f32(_sum2, _sum3); | |||
| _H = vaddq_f32(_H, _sum2); | |||
| _rnn_H = vaddq_f32(_rnn_H, _sum2); | |||
| _H = tanh_ps(_H); | |||
| _rnn_H = tanh_ps(_rnn_H); | |||
| vst1q_f32((float*)gates + q, _H); | |||
| vst1q_f32((float*)gates + q, _rnn_H); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -149,10 +149,10 @@ static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M | |||
| { | |||
| int q = qq * 4; | |||
| float32x4_t _H = vld1q_f32((float*)gates + q); | |||
| float32x4_t _rnn_H = vld1q_f32((float*)gates + q); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _rnn_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -196,7 +196,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8); | |||
| const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8); | |||
| float16x8_t _H = vld1q_f16((const __fp16*)bias_c + q); | |||
| float16x8_t _rnn_H = vld1q_f16((const __fp16*)bias_c + q); | |||
| float16x8_t _sum1 = vdupq_n_f16(0.f); | |||
| float16x8_t _sum2 = vdupq_n_f16(0.f); | |||
| float16x8_t _sum3 = vdupq_n_f16(0.f); | |||
| @@ -209,7 +209,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x8_t _weight_xc_1 = vld1q_f16(weight_xc_ptr + 8); | |||
| float16x8_t _weight_xc_2 = vld1q_f16(weight_xc_ptr + 16); | |||
| float16x8_t _weight_xc_3 = vld1q_f16(weight_xc_ptr + 24); | |||
| _H = vfmaq_lane_f16(_H, _weight_xc, _x, 0); | |||
| _rnn_H = vfmaq_lane_f16(_rnn_H, _weight_xc, _x, 0); | |||
| _sum1 = vfmaq_lane_f16(_sum1, _weight_xc_1, _x, 1); | |||
| _sum2 = vfmaq_lane_f16(_sum2, _weight_xc_2, _x, 2); | |||
| _sum3 = vfmaq_lane_f16(_sum3, _weight_xc_3, _x, 3); | |||
| @@ -220,7 +220,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| { | |||
| float16x8_t _x = vdupq_n_f16(x[i]); | |||
| float16x8_t _weight_xc = vld1q_f16(weight_xc_ptr); | |||
| _H = vfmaq_f16(_H, _weight_xc, _x); | |||
| _rnn_H = vfmaq_f16(_rnn_H, _weight_xc, _x); | |||
| weight_xc_ptr += 8; | |||
| } | |||
| @@ -233,7 +233,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x8_t _weight_hc_1 = vld1q_f16(weight_hc_ptr + 8); | |||
| float16x8_t _weight_hc_2 = vld1q_f16(weight_hc_ptr + 16); | |||
| float16x8_t _weight_hc_3 = vld1q_f16(weight_hc_ptr + 24); | |||
| _H = vfmaq_lane_f16(_H, _weight_hc, _hidden_state, 0); | |||
| _rnn_H = vfmaq_lane_f16(_rnn_H, _weight_hc, _hidden_state, 0); | |||
| _sum1 = vfmaq_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1); | |||
| _sum2 = vfmaq_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2); | |||
| _sum3 = vfmaq_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3); | |||
| @@ -244,17 +244,17 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| { | |||
| float16x8_t _hidden_state = vdupq_n_f16((__fp16)hidden_state[i]); | |||
| float16x8_t _weight_hc = vld1q_f16(weight_hc_ptr); | |||
| _H = vfmaq_f16(_H, _weight_hc, _hidden_state); | |||
| _rnn_H = vfmaq_f16(_rnn_H, _weight_hc, _hidden_state); | |||
| weight_hc_ptr += 8; | |||
| } | |||
| _H = vaddq_f16(_H, _sum1); | |||
| _rnn_H = vaddq_f16(_rnn_H, _sum1); | |||
| _sum2 = vaddq_f16(_sum2, _sum3); | |||
| _H = vaddq_f16(_H, _sum2); | |||
| _rnn_H = vaddq_f16(_rnn_H, _sum2); | |||
| float32x4_t _H32low = tanh_ps(vcvt_f32_f16(vget_low_f16(_H))); | |||
| float32x4_t _H32high = tanh_ps(vcvt_f32_f16(vget_high_f16(_H))); | |||
| float32x4_t _H32low = tanh_ps(vcvt_f32_f16(vget_low_f16(_rnn_H))); | |||
| float32x4_t _H32high = tanh_ps(vcvt_f32_f16(vget_high_f16(_rnn_H))); | |||
| vst1q_f32((float*)gates + q, _H32low); | |||
| vst1q_f32((float*)gates + q + 4, _H32high); | |||
| @@ -268,7 +268,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| const __fp16* weight_xc_ptr = weight_xc.row<const __fp16>(q / 8 + (q % 8) / 4); | |||
| const __fp16* weight_hc_ptr = weight_hc.row<const __fp16>(q / 8 + (q % 8) / 4); | |||
| float16x4_t _H = vld1_f16((const __fp16*)bias_c + q); | |||
| float16x4_t _rnn_H = vld1_f16((const __fp16*)bias_c + q); | |||
| float16x4_t _sum1 = vdup_n_f16(0.f); | |||
| float16x4_t _sum2 = vdup_n_f16(0.f); | |||
| float16x4_t _sum3 = vdup_n_f16(0.f); | |||
| @@ -281,7 +281,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x4_t _weight_xc_1 = vld1_f16(weight_xc_ptr + 4); | |||
| float16x4_t _weight_xc_2 = vld1_f16(weight_xc_ptr + 8); | |||
| float16x4_t _weight_xc_3 = vld1_f16(weight_xc_ptr + 12); | |||
| _H = vfma_lane_f16(_H, _weight_xc, _x, 0); | |||
| _rnn_H = vfma_lane_f16(_rnn_H, _weight_xc, _x, 0); | |||
| _sum1 = vfma_lane_f16(_sum1, _weight_xc_1, _x, 1); | |||
| _sum2 = vfma_lane_f16(_sum2, _weight_xc_2, _x, 2); | |||
| _sum3 = vfma_lane_f16(_sum3, _weight_xc_3, _x, 3); | |||
| @@ -292,7 +292,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| { | |||
| float16x4_t _x = vdup_n_f16(x[i]); | |||
| float16x4_t _weight_xc = vld1_f16(weight_xc_ptr); | |||
| _H = vfma_f16(_H, _weight_xc, _x); | |||
| _rnn_H = vfma_f16(_rnn_H, _weight_xc, _x); | |||
| weight_xc_ptr += 4; | |||
| } | |||
| @@ -305,7 +305,7 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float16x4_t _weight_hc_1 = vld1_f16(weight_hc_ptr + 4); | |||
| float16x4_t _weight_hc_2 = vld1_f16(weight_hc_ptr + 8); | |||
| float16x4_t _weight_hc_3 = vld1_f16(weight_hc_ptr + 12); | |||
| _H = vfma_lane_f16(_H, _weight_hc, _hidden_state, 0); | |||
| _rnn_H = vfma_lane_f16(_rnn_H, _weight_hc, _hidden_state, 0); | |||
| _sum1 = vfma_lane_f16(_sum1, _weight_hc_1, _hidden_state, 1); | |||
| _sum2 = vfma_lane_f16(_sum2, _weight_hc_2, _hidden_state, 2); | |||
| _sum3 = vfma_lane_f16(_sum3, _weight_hc_3, _hidden_state, 3); | |||
| @@ -316,16 +316,16 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| { | |||
| float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]); | |||
| float16x4_t _weight_hc = vld1_f16(weight_hc_ptr); | |||
| _H = vfma_f16(_H, _weight_hc, _hidden_state); | |||
| _rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state); | |||
| weight_hc_ptr += 4; | |||
| } | |||
| _H = vadd_f16(_H, _sum1); | |||
| _rnn_H = vadd_f16(_rnn_H, _sum1); | |||
| _sum2 = vadd_f16(_sum2, _sum3); | |||
| _H = vadd_f16(_H, _sum2); | |||
| _rnn_H = vadd_f16(_rnn_H, _sum2); | |||
| float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_H)); | |||
| float32x4_t _H32 = tanh_ps(vcvt_f32_f16(_rnn_H)); | |||
| vst1q_f32((float*)gates + q, _H32); | |||
| } | |||
| @@ -364,10 +364,10 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| { | |||
| int q = qq * 4; | |||
| float32x4_t _H = vld1q_f32((float*)gates + q); | |||
| float32x4_t _rnn_H = vld1q_f32((float*)gates + q); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| vst1q_f32(hidden_ptr + q, _rnn_H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| @@ -268,11 +268,11 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i | |||
| __m128 _rows1 = (__m128)__lsx_vld(rows1p, 0); | |||
| __m128 _rows2 = (__m128)__lsx_vld(rows2p, 0); | |||
| __m128 _rows3 = (__m128)__lsx_vld(rows3p, 0); | |||
| __m128 _D = __lsx_vfmul_s(_rows0, _b0); | |||
| _D = __lsx_vfmadd_s(_b1, _rows1, _D); | |||
| _D = __lsx_vfmadd_s(_b2, _rows2, _D); | |||
| _D = __lsx_vfmadd_s(_b3, _rows3, _D); | |||
| __lsx_vst(_D, Dp, 0); | |||
| __m128 _Dp = __lsx_vfmul_s(_rows0, _b0); | |||
| _Dp = __lsx_vfmadd_s(_b1, _rows1, _Dp); | |||
| _Dp = __lsx_vfmadd_s(_b2, _rows2, _Dp); | |||
| _Dp = __lsx_vfmadd_s(_b3, _rows3, _Dp); | |||
| __lsx_vst(_Dp, Dp, 0); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -143,18 +143,18 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| __m128 _rows0 = (__m128)__lsx_vld(rows0p, 0); | |||
| __m128 _rows1 = (__m128)__lsx_vld(rows1p, 0); | |||
| __m128 _D = __lsx_vfmul_s(_rows0, _b0); | |||
| _D = __lsx_vfmadd_s(_b1, _rows1, _D); | |||
| __m128 _Dp = __lsx_vfmul_s(_rows0, _b0); | |||
| _Dp = __lsx_vfmadd_s(_b1, _rows1, _Dp); | |||
| __lsx_vst(_D, Dp, 0); | |||
| __lsx_vst(_Dp, Dp, 0); | |||
| __m128 _rows0n = (__m128)__lsx_vld(rows0p + 4, 0); | |||
| __m128 _rows1n = (__m128)__lsx_vld(rows1p + 4, 0); | |||
| __m128 _Dn = __lsx_vfmul_s(_rows0n, _b0); | |||
| _Dn = __lsx_vfmadd_s(_b1, _rows1n, _Dn); | |||
| __m128 _Dpn = __lsx_vfmul_s(_rows0n, _b0); | |||
| _Dpn = __lsx_vfmadd_s(_b1, _rows1n, _Dpn); | |||
| __lsx_vst(_Dn, Dp + 4, 0); | |||
| __lsx_vst(_Dpn, Dp + 4, 0); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -109,9 +109,9 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| __m128 _rows0 = (__m128)__lsx_vld(rows0p, 0); | |||
| __m128 _rows1 = (__m128)__lsx_vld(rows1p, 0); | |||
| __m128 _D = __lsx_vfmul_s(_rows0, _b0); | |||
| _D = __lsx_vfmadd_s(_b1, _rows1, _D); | |||
| __lsx_vst(_D, Dp, 0); | |||
| __m128 _Dp = __lsx_vfmul_s(_rows0, _b0); | |||
| _Dp = __lsx_vfmadd_s(_b1, _rows1, _Dp); | |||
| __lsx_vst(_Dp, Dp, 0); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -268,11 +268,11 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i | |||
| v4f32 _rows1 = (v4f32)__msa_ld_w(rows1p, 0); | |||
| v4f32 _rows2 = (v4f32)__msa_ld_w(rows2p, 0); | |||
| v4f32 _rows3 = (v4f32)__msa_ld_w(rows3p, 0); | |||
| v4f32 _D = __msa_fmul_w(_rows0, _b0); | |||
| _D = __msa_fmadd_w(_D, _rows1, _b1); | |||
| _D = __msa_fmadd_w(_D, _rows2, _b2); | |||
| _D = __msa_fmadd_w(_D, _rows3, _b3); | |||
| __msa_st_w((v4i32)_D, Dp, 0); | |||
| v4f32 _Dp = __msa_fmul_w(_rows0, _b0); | |||
| _Dp = __msa_fmadd_w(_Dp, _rows1, _b1); | |||
| _Dp = __msa_fmadd_w(_Dp, _rows2, _b2); | |||
| _Dp = __msa_fmadd_w(_Dp, _rows3, _b3); | |||
| __msa_st_w((v4i32)_Dp, Dp, 0); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -143,18 +143,18 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| v4f32 _rows0 = (v4f32)__msa_ld_w(rows0p, 0); | |||
| v4f32 _rows1 = (v4f32)__msa_ld_w(rows1p, 0); | |||
| v4f32 _D = __msa_fmul_w(_rows0, _b0); | |||
| _D = __msa_fmadd_w(_D, _rows1, _b1); | |||
| v4f32 _Dp = __msa_fmul_w(_rows0, _b0); | |||
| _Dp = __msa_fmadd_w(_Dp, _rows1, _b1); | |||
| __msa_st_w((v4i32)_D, Dp, 0); | |||
| __msa_st_w((v4i32)_Dp, Dp, 0); | |||
| v4f32 _rows0n = (v4f32)__msa_ld_w(rows0p + 4, 0); | |||
| v4f32 _rows1n = (v4f32)__msa_ld_w(rows1p + 4, 0); | |||
| v4f32 _Dn = __msa_fmul_w(_rows0n, _b0); | |||
| _Dn = __msa_fmadd_w(_Dn, _rows1n, _b1); | |||
| v4f32 _Dpn = __msa_fmul_w(_rows0n, _b0); | |||
| _Dpn = __msa_fmadd_w(_Dpn, _rows1n, _b1); | |||
| __msa_st_w((v4i32)_Dn, Dp + 4, 0); | |||
| __msa_st_w((v4i32)_Dpn, Dp + 4, 0); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -109,9 +109,9 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| v4f32 _rows0 = (v4f32)__msa_ld_w(rows0p, 0); | |||
| v4f32 _rows1 = (v4f32)__msa_ld_w(rows1p, 0); | |||
| v4f32 _D = __msa_fmul_w(_rows0, _b0); | |||
| _D = __msa_fmadd_w(_D, _rows1, _b1); | |||
| __msa_st_w((v4i32)_D, Dp, 0); | |||
| v4f32 _Dp = __msa_fmul_w(_rows0, _b0); | |||
| _Dp = __msa_fmadd_w(_Dp, _rows1, _b1); | |||
| __msa_st_w((v4i32)_Dp, Dp, 0); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -226,9 +226,9 @@ static void resize_bicubic_image_packn(const Mat& src, Mat& dst, float* alpha, i | |||
| vfloat32m1_t _rows2 = vle32_v_f32m1(rows2p, vl); | |||
| vfloat32m1_t _rows3 = vle32_v_f32m1(rows3p, vl); | |||
| vfloat32m1_t _D = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vfloat32m1_t _Dp = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vse32_v_f32m1(Dp, _D, vl); | |||
| vse32_v_f32m1(Dp, _Dp, vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -226,9 +226,9 @@ static void resize_bicubic_image_packn_fp16s(const Mat& src, Mat& dst, float* al | |||
| vfloat32m2_t _rows2 = vle32_v_f32m2(rows2p, vl); | |||
| vfloat32m2_t _rows3 = vle32_v_f32m2(rows3p, vl); | |||
| vfloat32m2_t _D = vfmacc_vf_f32m2(vfmacc_vf_f32m2(vfmacc_vf_f32m2(vfmul_vf_f32m2(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vfloat32m2_t _Dp = vfmacc_vf_f32m2(vfmacc_vf_f32m2(vfmacc_vf_f32m2(vfmul_vf_f32m2(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vse16_v_f16m1(Dp, vfncvt_f_f_w_f16m1(_D, vl), vl); | |||
| vse16_v_f16m1(Dp, vfncvt_f_f_w_f16m1(_Dp, vl), vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -455,9 +455,9 @@ static void resize_bicubic_image_packn_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| vfloat16m1_t _rows2 = vle16_v_f16m1(rows2p, vl); | |||
| vfloat16m1_t _rows3 = vle16_v_f16m1(rows3p, vl); | |||
| vfloat16m1_t _D = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vfloat16m1_t _Dp = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_rows0, b0, vl), b1, _rows1, vl), b2, _rows2, vl), b3, _rows3, vl); | |||
| vse16_v_f16m1(Dp, _D, vl); | |||
| vse16_v_f16m1(Dp, _Dp, vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -200,9 +200,9 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| vfloat32m8_t _rows0 = vle32_v_f32m8(rows0p, vl); | |||
| vfloat32m8_t _rows1 = vle32_v_f32m8(rows1p, vl); | |||
| vfloat32m8_t _D = vfmacc_vf_f32m8(vfmul_vf_f32m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat32m8_t _Dp = vfmacc_vf_f32m8(vfmul_vf_f32m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse32_v_f32m8(Dp, _D, vl); | |||
| vse32_v_f32m8(Dp, _Dp, vl); | |||
| Dp += vl; | |||
| rows0p += vl; | |||
| @@ -136,9 +136,9 @@ static void resize_bilinear_image_fp16s(const Mat& src, Mat& dst, float* alpha, | |||
| vfloat32m8_t _rows0 = vle32_v_f32m8(rows0p, vl); | |||
| vfloat32m8_t _rows1 = vle32_v_f32m8(rows1p, vl); | |||
| vfloat32m8_t _D = vfmacc_vf_f32m8(vfmul_vf_f32m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat32m8_t _Dp = vfmacc_vf_f32m8(vfmul_vf_f32m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse16_v_f16m4(Dp, vfncvt_f_f_w_f16m4(_D, vl), vl); | |||
| vse16_v_f16m4(Dp, vfncvt_f_f_w_f16m4(_Dp, vl), vl); | |||
| Dp += vl; | |||
| rows0p += vl; | |||
| @@ -237,9 +237,9 @@ static void resize_bilinear_image_fp16sa(const Mat& src, Mat& dst, __fp16* alpha | |||
| vfloat16m8_t _rows0 = vle16_v_f16m8(rows0p, vl); | |||
| vfloat16m8_t _rows1 = vle16_v_f16m8(rows1p, vl); | |||
| vfloat16m8_t _D = vfmacc_vf_f16m8(vfmul_vf_f16m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat16m8_t _Dp = vfmacc_vf_f16m8(vfmul_vf_f16m8(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse16_v_f16m8(Dp, _D, vl); | |||
| vse16_v_f16m8(Dp, _Dp, vl); | |||
| Dp += vl; | |||
| rows0p += vl; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_packn(const Mat& src, Mat& dst, float* alpha, | |||
| vfloat32m1_t _rows0 = vle32_v_f32m1(rows0p, vl); | |||
| vfloat32m1_t _rows1 = vle32_v_f32m1(rows1p, vl); | |||
| vfloat32m1_t _D = vfmacc_vf_f32m1(vfmul_vf_f32m1(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat32m1_t _Dp = vfmacc_vf_f32m1(vfmul_vf_f32m1(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse32_v_f32m1(Dp, _D, vl); | |||
| vse32_v_f32m1(Dp, _Dp, vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -106,9 +106,9 @@ static void resize_bilinear_image_packn_fp16s(const Mat& src, Mat& dst, float* a | |||
| vfloat32m2_t _rows0 = vle32_v_f32m2(rows0p, vl); | |||
| vfloat32m2_t _rows1 = vle32_v_f32m2(rows1p, vl); | |||
| vfloat32m2_t _D = vfmacc_vf_f32m2(vfmul_vf_f32m2(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat32m2_t _Dp = vfmacc_vf_f32m2(vfmul_vf_f32m2(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse16_v_f16m1(Dp, vfncvt_f_f_w_f16m1(_D, vl), vl); | |||
| vse16_v_f16m1(Dp, vfncvt_f_f_w_f16m1(_Dp, vl), vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -213,9 +213,9 @@ static void resize_bilinear_image_packn_fp16sa(const Mat& src, Mat& dst, __fp16* | |||
| vfloat16m1_t _rows0 = vle16_v_f16m1(rows0p, vl); | |||
| vfloat16m1_t _rows1 = vle16_v_f16m1(rows1p, vl); | |||
| vfloat16m1_t _D = vfmacc_vf_f16m1(vfmul_vf_f16m1(_rows0, b0, vl), b1, _rows1, vl); | |||
| vfloat16m1_t _Dp = vfmacc_vf_f16m1(vfmul_vf_f16m1(_rows0, b0, vl), b1, _rows1, vl); | |||
| vse16_v_f16m1(Dp, _D, vl); | |||
| vse16_v_f16m1(Dp, _Dp, vl); | |||
| Dp += packn; | |||
| rows0p += packn; | |||
| @@ -264,11 +264,11 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo | |||
| __m256 _rows1 = _mm256_loadu_ps(rows1p); | |||
| __m256 _rows2 = _mm256_loadu_ps(rows2p); | |||
| __m256 _rows3 = _mm256_loadu_ps(rows3p); | |||
| __m256 _D = _mm256_mul_ps(_rows0, _b0_256); | |||
| _D = _mm256_comp_fmadd_ps(_rows1, _b1_256, _D); | |||
| _D = _mm256_comp_fmadd_ps(_rows2, _b2_256, _D); | |||
| _D = _mm256_comp_fmadd_ps(_rows3, _b3_256, _D); | |||
| _mm256_storeu_ps(Dp, _D); | |||
| __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows2, _b2_256, _Dp); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows3, _b3_256, _Dp); | |||
| _mm256_storeu_ps(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -287,11 +287,11 @@ static void resize_bicubic_image(const Mat& src, Mat& dst, float* alpha, int* xo | |||
| __m128 _rows1 = _mm_loadu_ps(rows1p); | |||
| __m128 _rows2 = _mm_loadu_ps(rows2p); | |||
| __m128 _rows3 = _mm_loadu_ps(rows3p); | |||
| __m128 _D = _mm_mul_ps(_rows0, _b0_128); | |||
| _D = _mm_comp_fmadd_ps(_rows1, _b1_128, _D); | |||
| _D = _mm_comp_fmadd_ps(_rows2, _b2_128, _D); | |||
| _D = _mm_comp_fmadd_ps(_rows3, _b3_128, _D); | |||
| _mm_storeu_ps(Dp, _D); | |||
| __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); | |||
| _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); | |||
| _Dp = _mm_comp_fmadd_ps(_rows2, _b2_128, _Dp); | |||
| _Dp = _mm_comp_fmadd_ps(_rows3, _b3_128, _Dp); | |||
| _mm_storeu_ps(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -268,11 +268,11 @@ static void resize_bicubic_image_pack16(const Mat& src, Mat& dst, float* alpha, | |||
| __m512 _rows1 = _mm512_load_ps(rows1p); | |||
| __m512 _rows2 = _mm512_load_ps(rows2p); | |||
| __m512 _rows3 = _mm512_load_ps(rows3p); | |||
| __m512 _D = _mm512_mul_ps(_rows0, _b0); | |||
| _D = _mm512_fmadd_ps(_rows1, _b1, _D); | |||
| _D = _mm512_fmadd_ps(_rows2, _b2, _D); | |||
| _D = _mm512_fmadd_ps(_rows3, _b3, _D); | |||
| _mm512_store_ps(Dp, _D); | |||
| __m512 _Dp = _mm512_mul_ps(_rows0, _b0); | |||
| _Dp = _mm512_fmadd_ps(_rows1, _b1, _Dp); | |||
| _Dp = _mm512_fmadd_ps(_rows2, _b2, _Dp); | |||
| _Dp = _mm512_fmadd_ps(_rows3, _b3, _Dp); | |||
| _mm512_store_ps(Dp, _Dp); | |||
| Dp += 16; | |||
| rows0p += 16; | |||
| @@ -268,11 +268,11 @@ static void resize_bicubic_image_pack4(const Mat& src, Mat& dst, float* alpha, i | |||
| __m128 _rows1 = _mm_load_ps(rows1p); | |||
| __m128 _rows2 = _mm_load_ps(rows2p); | |||
| __m128 _rows3 = _mm_load_ps(rows3p); | |||
| __m128 _D = _mm_mul_ps(_rows0, _b0); | |||
| _D = _mm_comp_fmadd_ps(_rows1, _b1, _D); | |||
| _D = _mm_comp_fmadd_ps(_rows2, _b2, _D); | |||
| _D = _mm_comp_fmadd_ps(_rows3, _b3, _D); | |||
| _mm_store_ps(Dp, _D); | |||
| __m128 _Dp = _mm_mul_ps(_rows0, _b0); | |||
| _Dp = _mm_comp_fmadd_ps(_rows1, _b1, _Dp); | |||
| _Dp = _mm_comp_fmadd_ps(_rows2, _b2, _Dp); | |||
| _Dp = _mm_comp_fmadd_ps(_rows3, _b3, _Dp); | |||
| _mm_store_ps(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -268,11 +268,11 @@ static void resize_bicubic_image_pack8(const Mat& src, Mat& dst, float* alpha, i | |||
| __m256 _rows1 = _mm256_load_ps(rows1p); | |||
| __m256 _rows2 = _mm256_load_ps(rows2p); | |||
| __m256 _rows3 = _mm256_load_ps(rows3p); | |||
| __m256 _D = _mm256_mul_ps(_rows0, _b0); | |||
| _D = _mm256_comp_fmadd_ps(_rows1, _b1, _D); | |||
| _D = _mm256_comp_fmadd_ps(_rows2, _b2, _D); | |||
| _D = _mm256_comp_fmadd_ps(_rows3, _b3, _D); | |||
| _mm256_store_ps(Dp, _D); | |||
| __m256 _Dp = _mm256_mul_ps(_rows0, _b0); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows1, _b1, _Dp); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows2, _b2, _Dp); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows3, _b3, _Dp); | |||
| _mm256_store_ps(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -137,9 +137,9 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| { | |||
| __m256 _rows0 = _mm256_loadu_ps(rows0p); | |||
| __m256 _rows1 = _mm256_loadu_ps(rows1p); | |||
| __m256 _D = _mm256_mul_ps(_rows0, _b0_256); | |||
| _D = _mm256_comp_fmadd_ps(_rows1, _b1_256, _D); | |||
| _mm256_storeu_ps(Dp, _D); | |||
| __m256 _Dp = _mm256_mul_ps(_rows0, _b0_256); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows1, _b1_256, _Dp); | |||
| _mm256_storeu_ps(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -152,9 +152,9 @@ static void resize_bilinear_image(const Mat& src, Mat& dst, float* alpha, int* x | |||
| { | |||
| __m128 _rows0 = _mm_loadu_ps(rows0p); | |||
| __m128 _rows1 = _mm_loadu_ps(rows1p); | |||
| __m128 _D = _mm_mul_ps(_rows0, _b0_128); | |||
| _D = _mm_comp_fmadd_ps(_rows1, _b1_128, _D); | |||
| _mm_storeu_ps(Dp, _D); | |||
| __m128 _Dp = _mm_mul_ps(_rows0, _b0_128); | |||
| _Dp = _mm_comp_fmadd_ps(_rows1, _b1_128, _Dp); | |||
| _mm_storeu_ps(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -109,9 +109,9 @@ static void resize_bilinear_image_pack16(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| __m512 _rows0 = _mm512_load_ps(rows0p); | |||
| __m512 _rows1 = _mm512_load_ps(rows1p); | |||
| __m512 _D = _mm512_mul_ps(_rows0, _b0); | |||
| _D = _mm512_fmadd_ps(_rows1, _b1, _D); | |||
| _mm512_store_ps(Dp, _D); | |||
| __m512 _Dp = _mm512_mul_ps(_rows0, _b0); | |||
| _Dp = _mm512_fmadd_ps(_rows1, _b1, _Dp); | |||
| _mm512_store_ps(Dp, _Dp); | |||
| Dp += 16; | |||
| rows0p += 16; | |||
| @@ -109,9 +109,9 @@ static void resize_bilinear_image_pack4(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| __m128 _rows0 = _mm_load_ps(rows0p); | |||
| __m128 _rows1 = _mm_load_ps(rows1p); | |||
| __m128 _D = _mm_mul_ps(_rows0, _b0); | |||
| _D = _mm_comp_fmadd_ps(_rows1, _b1, _D); | |||
| _mm_store_ps(Dp, _D); | |||
| __m128 _Dp = _mm_mul_ps(_rows0, _b0); | |||
| _Dp = _mm_comp_fmadd_ps(_rows1, _b1, _Dp); | |||
| _mm_store_ps(Dp, _Dp); | |||
| Dp += 4; | |||
| rows0p += 4; | |||
| @@ -109,9 +109,9 @@ static void resize_bilinear_image_pack8(const Mat& src, Mat& dst, float* alpha, | |||
| { | |||
| __m256 _rows0 = _mm256_load_ps(rows0p); | |||
| __m256 _rows1 = _mm256_load_ps(rows1p); | |||
| __m256 _D = _mm256_mul_ps(_rows0, _b0); | |||
| _D = _mm256_comp_fmadd_ps(_rows1, _b1, _D); | |||
| _mm256_store_ps(Dp, _D); | |||
| __m256 _Dp = _mm256_mul_ps(_rows0, _b0); | |||
| _Dp = _mm256_comp_fmadd_ps(_rows1, _b1, _Dp); | |||
| _mm256_store_ps(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -474,24 +474,24 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); | |||
| __m128 _I = sigmoid_sse(_IFOG_4x4_0); | |||
| __m128 _F = sigmoid_sse(_IFOG_4x4_1); | |||
| __m128 _O = sigmoid_sse(_IFOG_4x4_2); | |||
| __m128 _G = tanh_sse(_IFOG_4x4_3); | |||
| __m128 _lstm_I = sigmoid_sse(_IFOG_4x4_0); | |||
| __m128 _lstm_F = sigmoid_sse(_IFOG_4x4_1); | |||
| __m128 _lstm_O = sigmoid_sse(_IFOG_4x4_2); | |||
| __m128 _lstm_G = tanh_sse(_IFOG_4x4_3); | |||
| __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_I, _G)); | |||
| __m128 _H = _mm_mul_ps(_O, tanh_sse(_cell2)); | |||
| __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); | |||
| __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2)); | |||
| _mm_storeu_ps(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| _mm_storeu_ps(hidden_ptr + q, _H); | |||
| _mm_storeu_ps(output_data + q, _H); | |||
| _mm_storeu_ps(hidden_ptr + q, _lstm_H); | |||
| _mm_storeu_ps(output_data + q, _lstm_H); | |||
| } | |||
| else | |||
| { | |||
| _mm_storeu_ps(tmp_hidden_ptr + q, _H); | |||
| _mm_storeu_ps(tmp_hidden_ptr + q, _lstm_H); | |||
| } | |||
| } | |||
| #else // __SSE2__ | |||
| @@ -229,9 +229,9 @@ void resize_bilinear_c1(const unsigned char* src, int srcw, int srch, int srcstr | |||
| int16x4_t _acc16 = vshrn_n_s32(_acc, 2); | |||
| int16x4_t _acc16_1 = vshrn_n_s32(_acc_1, 2); | |||
| uint8x8_t _D = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| uint8x8_t _Dp = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| vst1_u8(Dp, _D); | |||
| vst1_u8(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -538,9 +538,9 @@ void resize_bilinear_c2(const unsigned char* src, int srcw, int srch, int srcstr | |||
| int16x4_t _acc16 = vshrn_n_s32(_acc, 2); | |||
| int16x4_t _acc16_1 = vshrn_n_s32(_acc_1, 2); | |||
| uint8x8_t _D = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| uint8x8_t _Dp = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| vst1_u8(Dp, _D); | |||
| vst1_u8(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -858,9 +858,9 @@ void resize_bilinear_c3(const unsigned char* src, int srcw, int srch, int srcstr | |||
| int16x4_t _acc16 = vshrn_n_s32(_acc, 2); | |||
| int16x4_t _acc16_1 = vshrn_n_s32(_acc_1, 2); | |||
| uint8x8_t _D = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| uint8x8_t _Dp = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| vst1_u8(Dp, _D); | |||
| vst1_u8(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -1158,9 +1158,9 @@ void resize_bilinear_c4(const unsigned char* src, int srcw, int srch, int srcstr | |||
| int16x4_t _acc16 = vshrn_n_s32(_acc, 2); | |||
| int16x4_t _acc16_1 = vshrn_n_s32(_acc_1, 2); | |||
| uint8x8_t _D = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| uint8x8_t _Dp = vqmovun_s16(vcombine_s16(_acc16, _acc16_1)); | |||
| vst1_u8(Dp, _D); | |||
| vst1_u8(Dp, _Dp); | |||
| Dp += 8; | |||
| rows0p += 8; | |||
| @@ -20,11 +20,12 @@ class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| def forward(self, x, y, z, w): | |||
| x, x_indices = torch.max(x, dim=1, keepdim=False) | |||
| y = torch.max(y) | |||
| w = torch.max(z, w) | |||
| z, z_indices = torch.max(z, dim=0, keepdim=True) | |||
| return x, x_indices, y, z, z_indices | |||
| return x, x_indices, y, z, z_indices, w | |||
| def test(): | |||
| net = Model() | |||
| @@ -34,16 +35,17 @@ def test(): | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| w = torch.rand(5, 9, 10) | |||
| a = net(x, y, z) | |||
| a = net(x, y, z, w) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod = torch.jit.trace(net, (x, y, z, w)) | |||
| mod.save("test_torch_max.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_max.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| os.system("../src/pnnx test_torch_max.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_max_pnnx | |||
| @@ -20,11 +20,12 @@ class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| def forward(self, x, y, z, w): | |||
| x, x_indices = torch.min(x, dim=1, keepdim=False) | |||
| y = torch.min(y) | |||
| w = torch.min(z, w) | |||
| z, z_indices = torch.min(z, dim=0, keepdim=True) | |||
| return x, x_indices, y, z, z_indices | |||
| return x, x_indices, y, z, z_indices, w | |||
| def test(): | |||
| net = Model() | |||
| @@ -34,16 +35,17 @@ def test(): | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| w = torch.rand(5, 9, 10) | |||
| a = net(x, y, z) | |||
| a = net(x, y, z, w) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod = torch.jit.trace(net, (x, y, z, w)) | |||
| mod.save("test_torch_min.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_min.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| os.system("../src/pnnx test_torch_min.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_min_pnnx | |||