| @@ -371,6 +371,69 @@ v4sf tan_ps_f32(v4sf x) { | |||||
| #undef c_cephes_log_q1 | #undef c_cephes_log_q1 | ||||
| #undef c_cephes_log_q2 | #undef c_cephes_log_q2 | ||||
| static const struct { | |||||
| float lower_range; | |||||
| float upper_range; | |||||
| float alpha_9; | |||||
| float alpha_7; | |||||
| float alpha_5; | |||||
| float alpha_3; | |||||
| float alpha_1; | |||||
| float beta_10; | |||||
| float beta_8; | |||||
| float beta_6; | |||||
| float beta_4; | |||||
| float beta_2; | |||||
| float beta_0; | |||||
| float one_half; | |||||
| } sigmoid_constants = { | |||||
| -18.0f, | |||||
| 18.0f, | |||||
| 4.37031012579801e-11f, | |||||
| 1.15627324459942e-07f, | |||||
| 6.08574864600143e-05f, | |||||
| 8.51377133304701e-03f, | |||||
| 2.48287947061529e-01f, | |||||
| 6.10247389755681e-13f, | |||||
| 5.76102136993427e-09f, | |||||
| 6.29106785017040e-06f, | |||||
| 1.70198817374094e-03f, | |||||
| 1.16817656904453e-01f, | |||||
| 9.93151921023180e-01f, | |||||
| 0.5f, | |||||
| }; | |||||
| v4sf sigmoid_ps_f32(v4sf src) { | |||||
| auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src); | |||||
| val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val); | |||||
| auto squared = vmulq_f32(val, val); | |||||
| auto p = vmlaq_f32( | |||||
| vdupq_n_f32(sigmoid_constants.alpha_7), squared, | |||||
| vdupq_n_f32(sigmoid_constants.alpha_9)); | |||||
| p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared); | |||||
| p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared); | |||||
| p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared); | |||||
| p = vmulq_f32(p, val); | |||||
| auto q = vmlaq_f32( | |||||
| vdupq_n_f32(sigmoid_constants.beta_8), squared, | |||||
| vdupq_n_f32(sigmoid_constants.beta_10)); | |||||
| q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared); | |||||
| q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared); | |||||
| q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared); | |||||
| q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared); | |||||
| return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half)); | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| float16x8_t sigmoid_ps_f16(float16x8_t x) { | |||||
| float32x4_t low = vcvt_f32_f16(vget_low_f16(x)); | |||||
| float32x4_t high = vcvt_f32_f16(vget_high_f16(x)); | |||||
| low = sigmoid_ps_f32(low); | |||||
| high = sigmoid_ps_f32(high); | |||||
| return vcombine_f16(vcvt_f16_f32(low), vcvt_f16_f32(high)); | |||||
| } | |||||
| #endif | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -54,11 +54,38 @@ v4sf cos_ps_f32(v4sf x); | |||||
| v4sf tan_ps_f32(v4sf x); | v4sf tan_ps_f32(v4sf x); | ||||
| static inline v4sf div_ps_f32(v4sf x, v4sf y) { | |||||
| #if MEGDNN_AARCH64 | |||||
| return vdivq_f32(x, y); | |||||
| #else | |||||
| //! armv7 not support vdiv, so compute the reciprocal and iterate again | |||||
| float32x4_t recp = vrecpeq_f32(y); | |||||
| recp = vmulq_f32(vrecpsq_f32(y, recp), recp); | |||||
| return vmulq_f32(x, recp); | |||||
| #endif | |||||
| } | |||||
| v4sf sigmoid_ps_f32(v4sf x); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| /** | /** | ||||
| * \brief compute for 8 half at once, the inner just invoke exp_ps_f32 twice | * \brief compute for 8 half at once, the inner just invoke exp_ps_f32 twice | ||||
| */ | */ | ||||
| float16x8_t exp_ps_f16(float16x8_t x); | float16x8_t exp_ps_f16(float16x8_t x); | ||||
| static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) { | |||||
| #if MEGDNN_AARCH64 | |||||
| return vdivq_f16(x, y); | |||||
| #else | |||||
| //! armv7 not support vdiv, so compute the reciprocal and iterate again | |||||
| float16x8_t recp = vrecpeq_f16(y); | |||||
| recp = vmulq_f16(vrecpsq_f16(y, recp), recp); | |||||
| return vmulq_f16(x, recp); | |||||
| #endif | |||||
| } | |||||
| float16x8_t sigmoid_ps_f16(float16x8_t x); | |||||
| #endif | #endif | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -47,24 +47,14 @@ struct FuseAddSigmoidOp; | |||||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | ||||
| } \ | } \ | ||||
| _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ | _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ | ||||
| auto zero_val = vdupq_n_##_func_suffix(0.f); \ | |||||
| auto one_val = vdupq_n_##_func_suffix(1.f); \ | |||||
| auto val1 = src0.val[0]; \ | auto val1 = src0.val[0]; \ | ||||
| auto val2 = src0.val[1]; \ | auto val2 = src0.val[1]; \ | ||||
| auto val3 = src1.val[0]; \ | auto val3 = src1.val[0]; \ | ||||
| auto val4 = src1.val[1]; \ | auto val4 = src1.val[1]; \ | ||||
| val1 = vaddq_##_func_suffix(val1, val3); \ | val1 = vaddq_##_func_suffix(val1, val3); \ | ||||
| val2 = vaddq_##_func_suffix(val2, val4); \ | val2 = vaddq_##_func_suffix(val2, val4); \ | ||||
| val1 = vsubq_##_func_suffix(zero_val, val1); \ | |||||
| val2 = vsubq_##_func_suffix(zero_val, val2); \ | |||||
| val1 = exp_ps_##_func_suffix(val1); \ | |||||
| val2 = exp_ps_##_func_suffix(val2); \ | |||||
| auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ | |||||
| auto recipe2 = vaddq_##_func_suffix(one_val, val2); \ | |||||
| val1 = vrecpeq_##_func_suffix(recipe1); \ | |||||
| val2 = vrecpeq_##_func_suffix(recipe2); \ | |||||
| val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \ | |||||
| val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), val2); \ | |||||
| val1 = sigmoid_ps_##_func_suffix(val1); \ | |||||
| val2 = sigmoid_ps_##_func_suffix(val2); \ | |||||
| return {{val1, val2}}; \ | return {{val1, val2}}; \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| @@ -33,34 +33,27 @@ struct SigmoidOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||||
| template <typename src_ctype, typename dst_ctype = src_ctype> | template <typename src_ctype, typename dst_ctype = src_ctype> | ||||
| struct SigmoidOp; | struct SigmoidOp; | ||||
| #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | |||||
| template <> \ | |||||
| struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ | |||||
| using SigmoidOpBase::SigmoidOpBase; \ | |||||
| using SigmoidOpBase::operator(); \ | |||||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||||
| void operator()(const _neon_type2& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem.val[0]); \ | |||||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||||
| } \ | |||||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem); \ | |||||
| } \ | |||||
| _neon_type2 operator()(const _neon_type2& src) const { \ | |||||
| return {{operator()(src.val[0]), operator()(src.val[1])}}; \ | |||||
| } \ | |||||
| _neon_type operator()(const _neon_type& src) const { \ | |||||
| auto zero_val = vdupq_n_##_func_suffix(0.f); \ | |||||
| auto one_val = vdupq_n_##_func_suffix(1.f); \ | |||||
| auto val1 = vsubq_##_func_suffix(zero_val, src); \ | |||||
| val1 = exp_ps_##_func_suffix(val1); \ | |||||
| auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ | |||||
| val1 = vrecpeq_##_func_suffix(recipe1); \ | |||||
| val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \ | |||||
| return val1; \ | |||||
| } \ | |||||
| #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | |||||
| template <> \ | |||||
| struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ | |||||
| using SigmoidOpBase::SigmoidOpBase; \ | |||||
| using SigmoidOpBase::operator(); \ | |||||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||||
| void operator()(const _neon_type2& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem.val[0]); \ | |||||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||||
| } \ | |||||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem); \ | |||||
| } \ | |||||
| _neon_type2 operator()(const _neon_type2& src) const { \ | |||||
| return {{operator()(src.val[0]), operator()(src.val[1])}}; \ | |||||
| } \ | |||||
| _neon_type operator()(const _neon_type& src) const { \ | |||||
| return sigmoid_ps_##_func_suffix(src); \ | |||||
| } \ | |||||
| }; | }; | ||||
| OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -318,7 +318,7 @@ def test_add_remove_output(): | |||||
| out = g.run(a.numpy(), b.numpy()) | out = g.run(a.numpy(), b.numpy()) | ||||
| np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy()) | np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy()) | ||||
| np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy()) | |||||
| np.testing.assert_almost_equal(out["new_o2"], (F.sigmoid((a + b))).numpy()) | |||||
| def test_query(): | def test_query(): | ||||