|
|
|
@@ -19,6 +19,33 @@ |
|
|
|
#include "nnacl/errorcode.h" |
|
|
|
#include "nnacl/op_base.h" |
|
|
|
|
|
|
|
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1) |
|
|
|
#define SimdFp32AvgPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci, \ |
|
|
|
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, \ |
|
|
|
in_h_index, in_w, in_w_index) \ |
|
|
|
do { \ |
|
|
|
MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf); \ |
|
|
|
MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf); \ |
|
|
|
for (int block_max_size = channel - block_num + 1; ci < block_max_size; ci += block_num) { \ |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; \ |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; \ |
|
|
|
MS_FLOAT_32xN(block_num) tmp_avg = MS_MOVN_F32(block_size, 0.0f); \ |
|
|
|
int real_count = 0; \ |
|
|
|
for (int h = real_win_h_start; h < real_win_h_end; h++) { \ |
|
|
|
for (int w = real_win_w_start; w < real_win_w_end; w++) { \ |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; \ |
|
|
|
tmp_avg = MS_ADD_F32(block_size, tmp_avg, MS_LD_F32(block_size, src_win_ptr)); \ |
|
|
|
++real_count; \ |
|
|
|
} \ |
|
|
|
} \ |
|
|
|
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); \ |
|
|
|
tmp_avg = MS_DIV_F32(block_size, tmp_avg, MS_MOVN_F32(block_size, real_count)); \ |
|
|
|
tmp_avg = MS_MAX_F32(block_size, tmp_avg, min_val_##block_num); \ |
|
|
|
tmp_avg = MS_MIN_F32(block_size, tmp_avg, max_val_##block_num); \ |
|
|
|
MS_ST_F32(block_size, dst_c_ptr, tmp_avg); \ |
|
|
|
} \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id, |
|
|
|
float minf, float maxf) { |
|
|
|
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_; |
|
|
|
@@ -28,16 +55,6 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam |
|
|
|
int out_plane = output_w * output_h; |
|
|
|
int out_tile_count = UP_DIV(out_plane, TILE_NUM); |
|
|
|
NNACL_CHECK_ZERO_RETURN_ERR(output_w); |
|
|
|
#ifdef ENABLE_AVX |
|
|
|
int c8 = channel / C8NUM * C8NUM; |
|
|
|
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf); |
|
|
|
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf); |
|
|
|
#endif |
|
|
|
#if defined(ENABLE_NEON) || defined(ENABLE_SSE) |
|
|
|
int c4 = channel / C4NUM * C4NUM; |
|
|
|
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf); |
|
|
|
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf); |
|
|
|
#endif |
|
|
|
|
|
|
|
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { |
|
|
|
int cal_start_index = thread_id * TILE_NUM; |
|
|
|
@@ -57,46 +74,11 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam |
|
|
|
int real_win_w_start = MSMAX(0, -in_w_index); |
|
|
|
int real_win_w_end = MSMIN(win_w, in_w - in_w_index); |
|
|
|
int ci = 0; |
|
|
|
#ifdef ENABLE_AVX |
|
|
|
for (; ci < c8; ci += C8NUM) { |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0); |
|
|
|
int real_count = 0; |
|
|
|
for (int h = real_win_h_start; h < real_win_h_end; h++) { |
|
|
|
for (int w = real_win_w_start; w < real_win_w_end; w++) { |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; |
|
|
|
tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(src_win_ptr)); |
|
|
|
++real_count; |
|
|
|
} // win_w loop |
|
|
|
} // win_h loop |
|
|
|
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); |
|
|
|
tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count)); |
|
|
|
tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8); |
|
|
|
tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8); |
|
|
|
MS_ST256_F32(dst_c_ptr, tmp_avg); |
|
|
|
} // ic8-1 loop |
|
|
|
#endif |
|
|
|
#if defined(ENABLE_NEON) || defined(ENABLE_SSE) |
|
|
|
for (; ci < c4; ci += C4NUM) { |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0); |
|
|
|
int real_count = 0; |
|
|
|
for (int h = real_win_h_start; h < real_win_h_end; h++) { |
|
|
|
for (int w = real_win_w_start; w < real_win_w_end; w++) { |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; |
|
|
|
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(src_win_ptr)); |
|
|
|
++real_count; |
|
|
|
} // win_w loop |
|
|
|
} // win_h loop |
|
|
|
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); |
|
|
|
tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count)); |
|
|
|
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value); |
|
|
|
tmp_avg = MS_MINQ_F32(tmp_avg, max_value); |
|
|
|
MS_STQ_F32(dst_c_ptr, tmp_avg); |
|
|
|
} // ic4-1 loop |
|
|
|
#endif |
|
|
|
|
|
|
|
MS_SIMD_RUN_NO_SCALAR(SimdFp32AvgPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci, |
|
|
|
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w, |
|
|
|
in_w_index); |
|
|
|
|
|
|
|
for (; ci < channel; ci++) { |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
@@ -377,6 +359,29 @@ int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1) |
|
|
|
#define SimdFp32MaxPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci, \ |
|
|
|
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, \ |
|
|
|
in_h_index, in_w, in_w_index) \ |
|
|
|
do { \ |
|
|
|
MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf); \ |
|
|
|
MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf); \ |
|
|
|
for (int block_max_size = channel - block_num + 1; ci < block_max_size; ci += block_num) { \ |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; \ |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; \ |
|
|
|
MS_FLOAT_32xN(block_num) tmp_max = MS_MOVN_F32(block_size, -FLT_MAX); \ |
|
|
|
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { \ |
|
|
|
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { \ |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; \ |
|
|
|
tmp_max = MS_MAX_F32(block_size, tmp_max, MS_LD_F32(block_size, src_win_ptr)); \ |
|
|
|
} \ |
|
|
|
} \ |
|
|
|
tmp_max = MS_MAX_F32(block_size, tmp_max, min_val_##block_num); \ |
|
|
|
tmp_max = MS_MIN_F32(block_size, tmp_max, max_val_##block_num); \ |
|
|
|
MS_ST_F32(block_size, dst_c_ptr, tmp_max); \ |
|
|
|
} \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id, |
|
|
|
float minf, float maxf) { |
|
|
|
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_; |
|
|
|
@@ -386,16 +391,6 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam |
|
|
|
int out_plane = output_w * output_h; |
|
|
|
int out_tile_count = UP_DIV(out_plane, TILE_NUM); |
|
|
|
NNACL_CHECK_ZERO_RETURN_ERR(output_w); |
|
|
|
#ifdef ENABLE_AVX |
|
|
|
int c8 = channel / C8NUM * C8NUM; |
|
|
|
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf); |
|
|
|
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf); |
|
|
|
#endif |
|
|
|
#if defined(ENABLE_NEON) || defined(ENABLE_SSE) |
|
|
|
int c4 = channel / C4NUM * C4NUM; |
|
|
|
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf); |
|
|
|
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf); |
|
|
|
#endif |
|
|
|
|
|
|
|
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) { |
|
|
|
int cal_start_index = thread_id * TILE_NUM; |
|
|
|
@@ -415,38 +410,11 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam |
|
|
|
int real_win_w_start = MSMAX(0, -in_w_index); |
|
|
|
int real_win_w_end = MSMIN(win_w, in_w - in_w_index); |
|
|
|
int ci = 0; |
|
|
|
#ifdef ENABLE_AVX |
|
|
|
for (; ci < c8; ci += C8NUM) { |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX); |
|
|
|
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { |
|
|
|
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; |
|
|
|
tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(src_win_ptr)); |
|
|
|
} // win_w loop |
|
|
|
} // win_h loop |
|
|
|
tmp_max = MS_MAX256_F32(tmp_max, min_value_8); |
|
|
|
tmp_max = MS_MIN256_F32(tmp_max, max_value_8); |
|
|
|
MS_ST256_F32(dst_c_ptr, tmp_max); |
|
|
|
} // ic8 loop |
|
|
|
#endif |
|
|
|
#if defined(ENABLE_NEON) || defined(ENABLE_SSE) |
|
|
|
for (; ci < c4; ci += C4NUM) { |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX); |
|
|
|
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { |
|
|
|
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { |
|
|
|
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; |
|
|
|
tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(src_win_ptr)); |
|
|
|
} // win_w loop |
|
|
|
} // win_h loop |
|
|
|
tmp_max = MS_MAXQ_F32(tmp_max, min_value); |
|
|
|
tmp_max = MS_MINQ_F32(tmp_max, max_value); |
|
|
|
MS_STQ_F32(dst_c_ptr, tmp_max); |
|
|
|
} // ic4 loop |
|
|
|
#endif |
|
|
|
|
|
|
|
MS_SIMD_RUN_NO_SCALAR(SimdFp32MaxPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci, |
|
|
|
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w, |
|
|
|
in_w_index); |
|
|
|
|
|
|
|
for (; ci < channel; ci++) { |
|
|
|
float *dst_c_ptr = dst_plane_ptr + ci; |
|
|
|
const float *src_c_ptr = src_plane_ptr + ci; |
|
|
|
|