| @@ -130,12 +130,11 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo | |||||
| int out_plane = output_w * output_h; | int out_plane = output_w * output_h; | ||||
| int out_tile_count = UP_DIV(out_plane, TILE_NUM); | int out_tile_count = UP_DIV(out_plane, TILE_NUM); | ||||
| int thread_num = pooling_param->thread_num_; | int thread_num = pooling_param->thread_num_; | ||||
| int c4 = UP_DIV(channel, C4NUM); | |||||
| // input channel is equal to output channel | |||||
| int c4 = UP_DIV(channel, C4NUM); /* oc && ic */ | |||||
| for (int batch = 0; batch < output_batch; batch++) { | for (int batch = 0; batch < output_batch; batch++) { | ||||
| int in_batch_offset = batch * in_h * in_w * channel; | |||||
| int out_batch_offset = batch * output_h * output_w * channel; | |||||
| const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; | |||||
| float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; | |||||
| for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { | ||||
| int cal_start_index = thread_id * TILE_NUM; | int cal_start_index = thread_id * TILE_NUM; | ||||
| int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); | ||||
| @@ -145,10 +144,18 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo | |||||
| int out_h_index = index / output_w; | int out_h_index = index / output_w; | ||||
| int in_w_index = out_w_index * stride_w - pad_w; | int in_w_index = out_w_index * stride_w - pad_w; | ||||
| int in_h_index = out_h_index * stride_h - pad_h; | int in_h_index = out_h_index * stride_h - pad_h; | ||||
| int out_plane_offset = out_batch_offset + index * channel; | |||||
| for (int j = 0; j < c4 - 1; j++) { | |||||
| int in_channel_offset = in_batch_offset + j * C4NUM; | |||||
| int out_channel_offset = out_plane_offset + j * C4NUM; | |||||
| const float *src_plane_ptr = src_b_ptr; | |||||
| float *dst_plane_ptr = dst_b_ptr + index * channel; | |||||
| int real_win_h_start = MSMAX(0, -in_h_index); | |||||
| int real_win_h_end = MSMIN(win_h, in_h - in_h_index); | |||||
| int resl_win_w_start = MSMAX(0, -in_w_index); | |||||
| int real_win_w_end = MSMIN(win_w, in_w - in_w_index); | |||||
| for (int ci = 0; ci < c4 - 1; ci++) { | |||||
| const float *src_c_ptr = src_plane_ptr + ci * C4NUM; | |||||
| float *dst_c_ptr = dst_plane_ptr + ci * C4NUM; | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX); | float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX); | ||||
| #else | #else | ||||
| @@ -157,51 +164,43 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo | |||||
| float tmp_max3 = -FLT_MAX; | float tmp_max3 = -FLT_MAX; | ||||
| float tmp_max4 = -FLT_MAX; | float tmp_max4 = -FLT_MAX; | ||||
| #endif | #endif | ||||
| for (int h = 0; h < win_h; h++) { | |||||
| for (int w = 0; w < win_w; w++) { | |||||
| if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || | |||||
| (in_w_index + w) >= in_w) { | |||||
| continue; | |||||
| } else { | |||||
| int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; | |||||
| for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { | |||||
| for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) { | |||||
| const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset)); | |||||
| tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr)); | |||||
| #else | #else | ||||
| tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset)); | |||||
| tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); | |||||
| tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); | |||||
| tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); | |||||
| tmp_max1 = fmax(tmp_max1, src_win_ptr[0]); | |||||
| tmp_max2 = fmax(tmp_max2, src_win_ptr[1]); | |||||
| tmp_max3 = fmax(tmp_max3, src_win_ptr[2]); | |||||
| tmp_max4 = fmax(tmp_max4, src_win_ptr[3]); | |||||
| #endif | #endif | ||||
| } | |||||
| } // win_w loop | } // win_w loop | ||||
| } // win_h loop | } // win_h loop | ||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| vst1q_f32(output_ptr + out_channel_offset, tmp_max); | |||||
| vst1q_f32(dst_c_ptr, tmp_max); | |||||
| #else | #else | ||||
| *(output_ptr + out_channel_offset) = tmp_max1; | |||||
| *(output_ptr + out_channel_offset + 1) = tmp_max2; | |||||
| *(output_ptr + out_channel_offset + 2) = tmp_max3; | |||||
| *(output_ptr + out_channel_offset + 3) = tmp_max4; | |||||
| dst_c_ptr[0] = tmp_max1; | |||||
| dst_c_ptr[1] = tmp_max2; | |||||
| dst_c_ptr[2] = tmp_max3; | |||||
| dst_c_ptr[3] = tmp_max4; | |||||
| #endif | #endif | ||||
| } // ic4-1 loop | } // ic4-1 loop | ||||
| int channel_s = (c4 - 1) * C4NUM; | int channel_s = (c4 - 1) * C4NUM; | ||||
| for (int k = channel_s; k < channel; k++) { | |||||
| int in_channel_offset = in_batch_offset + k; | |||||
| int out_channel_offset = out_plane_offset + k; | |||||
| for (int ci = channel_s; ci < channel; ci++) { | |||||
| float *dst_c_ptr = dst_plane_ptr + ci; | |||||
| const float *src_c_ptr = src_plane_ptr + ci; | |||||
| float tmp_max = -FLT_MAX; | float tmp_max = -FLT_MAX; | ||||
| for (int h = 0; h < win_h; h++) { | |||||
| for (int w = 0; w < win_w; w++) { | |||||
| if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || | |||||
| (in_w_index + w) >= in_w) { | |||||
| continue; | |||||
| } else { | |||||
| int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; | |||||
| tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); | |||||
| } | |||||
| for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { | |||||
| for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) { | |||||
| const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; | |||||
| tmp_max = fmax(tmp_max, src_win_ptr[0]); | |||||
| } // win_w loop | } // win_w loop | ||||
| } // win_h loop | } // win_h loop | ||||
| *(output_ptr + out_channel_offset) = tmp_max; | |||||
| dst_c_ptr[0] = tmp_max; | |||||
| } // channel_res loop | } // channel_res loop | ||||
| } // real_cal_num loop | } // real_cal_num loop | ||||
| } // out_plane loop | } // out_plane loop | ||||