|
|
|
@@ -15,104 +15,130 @@ |
|
|
|
*/ |
|
|
|
#include "nnacl/fp32/prelu_fp32.h" |
|
|
|
|
|
|
|
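// PReLU: out = in when in > 0, otherwise out = slope * in, with one slope value per channel.
// This variant walks `plane` rows of `channel_num` values in tiles of TILE_NUM rows; the AVX block below
// covers 8 channels per step, and LOAD256X8_F32 / STORE256X8_F32 are assumed to move an 8-row block
// strided by channel_num, i.e. an 8x8 (rows x channels) tile per iteration.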
void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) {
  int plane_tile = plane / TILE_NUM * TILE_NUM;
  int channel_num = prelu_param_->channel_num_;
  int plane_index = 0;
  for (; plane_index < plane_tile; plane_index += TILE_NUM) {
    float *in_plane_ptr = input + plane_index * channel_num;
    float *out_plane_ptr = output + plane_index * channel_num;
    int channel_index = 0;
#if defined(ENABLE_AVX)
    MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f);
    MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
    float *negetive_slope_value_8 = prelu_param_->slope_;
    int div_channel_c8 = prelu_param_->channel_num_ / C8NUM * C8NUM;
    for (; channel_index < div_channel_c8; channel_index += C8NUM) {
      MS_FLOAT32X8 slope_value_8 = MS_LD256_F32(negetive_slope_value_8 + channel_index);
      LOAD256X8_F32(src, in_plane_ptr + channel_index, channel_num)
      PRELU_CALCULATE_256X8(dst, src)
      STORE256X8_F32(out_plane_ptr + channel_index, channel_num, dst)
|
|
|
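// ARM64 assembly kernel for one 4x16 (rows x channels) block: v4-v7 hold 16 slope values, each ld1 pulls in
// 16 floats of a row (advancing by `step` bytes to the next row), fmul forms in * slope, fcmgt builds a
// per-lane mask of positive inputs, and bif keeps the original lanes where the mask is set and the scaled
// lanes elsewhere before st1 writes the row back out.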
#ifdef ENABLE_ARM64
inline void PRelu4x16(const float *in, float *out, float *cur_slope, size_t step) {
  asm volatile(
    "mov x10, %[in]\n"
    "mov x11, %[out]\n"
    "mov x12, %[cur_slope]\n"
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n"
    "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n"
    "fmul v16.4s, v0.4s, v4.4s\n"
    "fmul v17.4s, v1.4s, v5.4s\n"
    "fmul v18.4s, v2.4s, v6.4s\n"
    "fmul v19.4s, v3.4s, v7.4s\n"
    "fcmgt v20.4s, v0.4s, #0\n"
    "fcmgt v21.4s, v1.4s, #0\n"
    "fcmgt v22.4s, v2.4s, #0\n"
    "fcmgt v23.4s, v3.4s, #0\n"
    "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n"
    "bif v0.16b, v16.16b, v20.16b\n"
    "bif v1.16b, v17.16b, v21.16b\n"
    "bif v2.16b, v18.16b, v22.16b\n"
    "bif v3.16b, v19.16b, v23.16b\n"
    "fmul v8.4s, v24.4s, v4.4s\n"
    "fmul v9.4s, v25.4s, v5.4s\n"
    "fmul v10.4s, v26.4s, v6.4s\n"
    "fmul v11.4s, v27.4s, v7.4s\n"
    "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n"
    "fcmgt v12.4s, v24.4s, #0\n"
    "fcmgt v13.4s, v25.4s, #0\n"
    "fcmgt v14.4s, v26.4s, #0\n"
    "fcmgt v15.4s, v27.4s, #0\n"
    "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n"
    "bif v24.16b, v8.16b, v12.16b\n"
    "bif v25.16b, v9.16b, v13.16b\n"
    "bif v26.16b, v10.16b, v14.16b\n"
    "bif v27.16b, v11.16b, v15.16b\n"
    "fmul v16.4s, v0.4s, v4.4s\n"
    "fmul v17.4s, v1.4s, v5.4s\n"
    "fmul v18.4s, v2.4s, v6.4s\n"
    "fmul v19.4s, v3.4s, v7.4s\n"
    "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n"
    "fcmgt v20.4s, v0.4s, #0\n"
    "fcmgt v21.4s, v1.4s, #0\n"
    "fcmgt v22.4s, v2.4s, #0\n"
    "fcmgt v23.4s, v3.4s, #0\n"
    "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n"
    "bif v0.16b, v16.16b, v20.16b\n"
    "bif v1.16b, v17.16b, v21.16b\n"
    "bif v2.16b, v18.16b, v22.16b\n"
    "bif v3.16b, v19.16b, v23.16b\n"
    "fmul v8.4s, v24.4s, v4.4s\n"
    "fmul v9.4s, v25.4s, v5.4s\n"
    "fmul v10.4s, v26.4s, v6.4s\n"
    "fmul v11.4s, v27.4s, v7.4s\n"
    "fcmgt v12.4s, v24.4s, #0\n"
    "fcmgt v13.4s, v25.4s, #0\n"
    "fcmgt v14.4s, v26.4s, #0\n"
    "fcmgt v15.4s, v27.4s, #0\n"
    "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n"
    "bif v24.16b, v8.16b, v12.16b\n"
    "bif v25.16b, v9.16b, v13.16b\n"
    "bif v26.16b, v10.16b, v14.16b\n"
    "bif v27.16b, v11.16b, v15.16b\n"
    "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n"
    :
    : [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step)
    : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
      "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
}
#endif
|
|
|
|
|
|
|
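// Row-wise PReLU over rows [start, end) of a [rows, channel] buffer with a per-channel slope. On ARM64 the
// main loop takes 4 rows at a time and hands 16 channels per call to PRelu4x16 (step = channel * sizeof(float)
// is the byte stride between rows); leftover channels of those 4 rows are handled by the scalar loop below.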
void PRelu(const float *input, float *output, float *slope, int start, int end, int channel) {
  int i = start;
#ifdef ENABLE_ARM64
  for (; i < end - 3; i += 4) {
    const float *cur_in = input + i * channel;
    float *cur_out = output + i * channel;
    int j = 0;
    for (; j < channel - 15; j += 16) {
      const float *in = cur_in + j;
      float *out = cur_out + j;
      float *cur_slope = slope + j;
      size_t step = channel * sizeof(float);
      PRelu4x16(in, out, cur_slope, step);
    }
    for (; j < channel; j++) {
      cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]);
      cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : (cur_in[j + channel] * slope[j]);
      cur_out[j + 2 * channel] =
        (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]);
      cur_out[j + 3 * channel] =
        (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]);
    }
  }
#endif
|
|
|
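    // 128-bit fallback of the parameter-based path: 4 channels per step, with LOAD128X8_F32 / STORE128X8_F32
    // assumed to move 8 rows strided by channel_num, mirroring the 256-bit helpers above.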
    // note: First AVX processing, then SSE processing on X86 platform
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
    MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f);
    MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
    float *negetive_slope_value = prelu_param_->slope_;
    int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM;
    for (; channel_index < div_channel; channel_index += C4NUM) {
      MS_FLOAT32X4 slope_value = MS_LDQ_F32(negetive_slope_value + channel_index);
      LOAD128X8_F32(src, in_plane_ptr + channel_index, channel_num)
      PRELU_CALCULATE_128X8(dst, src)
      STORE128X8_F32(out_plane_ptr + channel_index, channel_num, dst)
|
|
|
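  // Rows not covered by the 4-row ARM64 block (or all rows on other targets) are handled one at a time; the
  // NEON/SSE intrinsic loop computes in * slope, compares in against zero with MS_CMPGTQ_F32, and
  // MS_BLENDQ_F32 selects the original value for positive lanes and the product otherwise.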
  for (; i < end; i++) {
    const float *cur_in = input + i * channel;
    float *cur_out = output + i * channel;
    int j = 0;
#if defined(ENABLE_ARM)
    for (; j < channel - 3; j += 4) {
      MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j);
      MS_FLOAT32X4 s = MS_LDQ_F32(slope + j);
      MS_FLOAT32X4 mul = MS_MULQ_F32(in, s);
      MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
      MS_FLOAT32X4 res = MS_BLENDQ_F32(mul, in, MS_CMPGTQ_F32(in, zero));
      MS_STQ_F32(cur_out + j, res);
    }
#endif
|
|
|
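    // Scalar tail of the parameter-based path: the channels left over after the vector loops are processed for
    // every row in the current TILE_NUM tile; (x < 0 ? x : 0) * slope + (x > 0 ? x : 0) is a branch-free PReLU.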
    for (; channel_index < channel_num; channel_index++) {
      float *in_c = in_plane_ptr + channel_index;
      float *out_c = out_plane_ptr + channel_index;
      for (int tile_i = 0; tile_i < TILE_NUM; tile_i++) {
        float *in_tile = in_c + tile_i * channel_num;
        float *out_tile = out_c + tile_i * channel_num;
        const float in_data = in_tile[0];
        out_tile[0] = (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
|
|
|
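    // Scalar tail for the channels the vector loop above did not cover.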
    for (; j < channel; j++) {
      if (cur_in[j] > 0) {
        cur_out[j] = cur_in[j];
      } else {
        cur_out[j] = cur_in[j] * slope[j];
      }
    }
  }
|
|
|
|
|
|
|
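  // Rows left over after the TILE_NUM tiling, handled one at a time with the same branch-free formula.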
  for (; plane_index < plane; plane_index++) {
    float *in_plane_ptr = input + plane_index * channel_num;
    float *out_plane_ptr = output + plane_index * channel_num;
    for (int channel_index = 0; channel_index < channel_num; channel_index++) {
      const float in_data = in_plane_ptr[channel_index];
      out_plane_ptr[channel_index] =
        (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
    }
  }
}
|
|
|
|
|
|
|
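// Shared-slope variant driven by PReluParameter: every channel uses slope_[0], work is split across threads
// by tile_block_, and each tile covers 64 floats on ARM64/AVX builds or 32 floats otherwise; the vector paths
// simply broadcast the single slope value.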
void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id) {
  for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
    int cal_index;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
    cal_index = j * 64;
#else
    cal_index = j * 32;
#endif

    float *input_ptr = input + cal_index;
    float *output_ptr = output + cal_index;
#if defined(ENABLE_AVX)
    MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f);
    MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
    MS_FLOAT32X8 slope_value_8 = MS_MOV256_F32(prelu_param_->slope_[0]);
    LOAD256X8_F32(src, input_ptr, 8)
    PRELU_CALCULATE_256X8(dst, src)
    STORE256X8_F32(output_ptr, 8, dst)
#elif defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
    MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f);
    MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
    MS_FLOAT32X4 slope_value = MS_MOVQ_F32(prelu_param_->slope_[0]);

    LOAD128X8_F32(src, input_ptr, 4)
#ifdef ENABLE_ARM64
    LOAD128X8_F32(src1, input_ptr + 32, 4)
#endif
    PRELU_CALCULATE_128X8(dst, src)
#ifdef ENABLE_ARM64
    PRELU_CALCULATE_128X8(dst1, src1)
#endif
    STORE128X8_F32(output_ptr, 4, dst)
#ifdef ENABLE_ARM64
    STORE128X8_F32(output_ptr + 32, 4, dst1)
#endif

#else
    const int cal_per_time = 32;
    for (int i = 0; i < cal_per_time; ++i) {
      float data = input_ptr[i];
      output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0);
|
|
|
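// Shared-slope PReLU: a single scalar slope applied to every element in the flat range [start, end).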
void PReluShareChannel(const float *input, float *output, float slope, int start, int end) {
  for (int i = start; i < end; i++) {
    if (input[i] > 0) {
      output[i] = input[i];
    } else {
      output[i] = input[i] * slope;
    }
#endif
  }
}