|
|
|
@@ -18,14 +18,31 @@ |
|
|
|
#include "nnacl/errorcode.h" |
|
|
|
|
|
|
|
int Fp32Relu(const float *src, int length, float *dst) { |
|
|
|
for (int i = 0; i < length; ++i) { |
|
|
|
int i = 0; |
|
|
|
#ifdef ENABLE_ARM |
|
|
|
float32x4_t zero_4 = vdupq_n_f32(0.0f); |
|
|
|
for (; i < length - 4; i += 4) { |
|
|
|
vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4)); |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; i < length; ++i) { |
|
|
|
dst[i] = src[i] > 0 ? src[i] : 0; |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int Fp32Relu6(const float *src, int length, float *dst) { |
|
|
|
for (int i = 0; i < length; ++i) { |
|
|
|
int i = 0; |
|
|
|
#ifdef ENABLE_ARM |
|
|
|
float32x4_t zero_4 = vdupq_n_f32(0.0f); |
|
|
|
float32x4_t six_4 = vdupq_n_f32(6.0f); |
|
|
|
for (; i < length - 4; i += 4) { |
|
|
|
float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4); |
|
|
|
dst_4 = vminq_f32(dst_4, six_4); |
|
|
|
vst1q_f32(dst + i, dst_4); |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; i < length; ++i) { |
|
|
|
if (src[i] < 0) { |
|
|
|
dst[i] = 0; |
|
|
|
} else { |
|
|
|
@@ -36,7 +53,18 @@ int Fp32Relu6(const float *src, int length, float *dst) { |
|
|
|
} |
|
|
|
|
|
|
|
int LRelu(const float *src, int length, float *dst, float alpha) { |
|
|
|
for (int i = 0; i < length; ++i) { |
|
|
|
int i = 0; |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
float32x4_t alpha_4 = vdupq_n_f32(alpha); |
|
|
|
for (; i < length - 4; i += 4) { |
|
|
|
float32x4_t src_4 = vld1q_f32(src + i); |
|
|
|
float32x4_t mul_4 = vmulq_f32(src_4, alpha_4); |
|
|
|
uint32x4_t flag = vclezq_f32(src_4); |
|
|
|
float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4); |
|
|
|
vst1q_f32(dst + i, dst_4); |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; i < length; ++i) { |
|
|
|
dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha); |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
|