From: @lx0095 Reviewed-by: @zhang_xue_tong Signed-off-by: @zhang_xue_tongpull/14088/MERGE
| @@ -15,104 +15,130 @@ | |||||
| */ | */ | ||||
| #include "nnacl/fp32/prelu_fp32.h" | #include "nnacl/fp32/prelu_fp32.h" | ||||
| void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) { | |||||
| int plane_tile = plane / TILE_NUM * TILE_NUM; | |||||
| int channel_num = prelu_param_->channel_num_; | |||||
| int plane_index = 0; | |||||
| for (; plane_index < plane_tile; plane_index += TILE_NUM) { | |||||
| float *in_plane_ptr = input + plane_index * channel_num; | |||||
| float *out_plane_ptr = output + plane_index * channel_num; | |||||
| int channel_index = 0; | |||||
| #if defined(ENABLE_AVX) | |||||
| MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f); | |||||
| MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f); | |||||
| float *negetive_slope_value_8 = prelu_param_->slope_; | |||||
| int div_channel_c8 = prelu_param_->channel_num_ / C8NUM * C8NUM; | |||||
| for (; channel_index < div_channel_c8; channel_index += C8NUM) { | |||||
| MS_FLOAT32X8 slope_value_8 = MS_LD256_F32(negetive_slope_value_8 + channel_index); | |||||
| LOAD256X8_F32(src, in_plane_ptr + channel_index, channel_num) | |||||
| PRELU_CALCULATE_256X8(dst, src) | |||||
| STORE256X8_F32(out_plane_ptr + channel_index, channel_num, dst) | |||||
| #ifdef ENABLE_ARM64 | |||||
| inline void PRelu4x16(const float *in, float *out, float *cur_slope, size_t step) { | |||||
| asm volatile( | |||||
| "mov x10, %[in]\n" | |||||
| "mov x11, %[out]\n" | |||||
| "mov x12, %[cur_slope]\n" | |||||
| "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n" | |||||
| "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" | |||||
| "fmul v16.4s, v0.4s, v4.4s\n" | |||||
| "fmul v17.4s, v1.4s, v5.4s\n" | |||||
| "fmul v18.4s, v2.4s, v6.4s\n" | |||||
| "fmul v19.4s, v3.4s, v7.4s\n" | |||||
| "fcmgt v20.4s, v0.4s, #0\n" | |||||
| "fcmgt v21.4s, v1.4s, #0\n" | |||||
| "fcmgt v22.4s, v2.4s, #0\n" | |||||
| "fcmgt v23.4s, v3.4s, #0\n" | |||||
| "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n" | |||||
| "bif v0.16b, v16.16b, v20.16b\n" | |||||
| "bif v1.16b, v17.16b, v21.16b\n" | |||||
| "bif v2.16b, v18.16b, v22.16b\n" | |||||
| "bif v3.16b, v19.16b, v23.16b\n" | |||||
| "fmul v8.4s, v24.4s, v4.4s\n" | |||||
| "fmul v9.4s, v25.4s, v5.4s\n" | |||||
| "fmul v10.4s, v26.4s, v6.4s\n" | |||||
| "fmul v11.4s, v27.4s, v7.4s\n" | |||||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" | |||||
| "fcmgt v12.4s, v24.4s, #0\n" | |||||
| "fcmgt v13.4s, v25.4s, #0\n" | |||||
| "fcmgt v14.4s, v26.4s, #0\n" | |||||
| "fcmgt v15.4s, v27.4s, #0\n" | |||||
| "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" | |||||
| "bif v24.16b, v8.16b, v12.16b\n" | |||||
| "bif v25.16b, v9.16b, v13.16b\n" | |||||
| "bif v26.16b, v10.16b, v14.16b\n" | |||||
| "bif v27.16b, v11.16b, v15.16b\n" | |||||
| "fmul v16.4s, v0.4s, v4.4s\n" | |||||
| "fmul v17.4s, v1.4s, v5.4s\n" | |||||
| "fmul v18.4s, v2.4s, v6.4s\n" | |||||
| "fmul v19.4s, v3.4s, v7.4s\n" | |||||
| "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n" | |||||
| "fcmgt v20.4s, v0.4s, #0\n" | |||||
| "fcmgt v21.4s, v1.4s, #0\n" | |||||
| "fcmgt v22.4s, v2.4s, #0\n" | |||||
| "fcmgt v23.4s, v3.4s, #0\n" | |||||
| "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n" | |||||
| "bif v0.16b, v16.16b, v20.16b\n" | |||||
| "bif v1.16b, v17.16b, v21.16b\n" | |||||
| "bif v2.16b, v18.16b, v22.16b\n" | |||||
| "bif v3.16b, v19.16b, v23.16b\n" | |||||
| "fmul v8.4s, v24.4s, v4.4s\n" | |||||
| "fmul v9.4s, v25.4s, v5.4s\n" | |||||
| "fmul v10.4s, v26.4s, v6.4s\n" | |||||
| "fmul v11.4s, v27.4s, v7.4s\n" | |||||
| "fcmgt v12.4s, v24.4s, #0\n" | |||||
| "fcmgt v13.4s, v25.4s, #0\n" | |||||
| "fcmgt v14.4s, v26.4s, #0\n" | |||||
| "fcmgt v15.4s, v27.4s, #0\n" | |||||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" | |||||
| "bif v24.16b, v8.16b, v12.16b\n" | |||||
| "bif v25.16b, v9.16b, v13.16b\n" | |||||
| "bif v26.16b, v10.16b, v14.16b\n" | |||||
| "bif v27.16b, v11.16b, v15.16b\n" | |||||
| "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n" | |||||
| : | |||||
| : [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step) | |||||
| : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", | |||||
| "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); | |||||
| } | |||||
| #endif | |||||
| void PRelu(const float *input, float *output, float *slope, int start, int end, int channel) { | |||||
| int i = start; | |||||
| #ifdef ENABLE_ARM64 | |||||
| for (; i < end - 3; i += 4) { | |||||
| const float *cur_in = input + i * channel; | |||||
| float *cur_out = output + i * channel; | |||||
| int j = 0; | |||||
| for (; j < channel - 15; j += 16) { | |||||
| const float *in = cur_in + j; | |||||
| float *out = cur_out + j; | |||||
| float *cur_slope = slope + j; | |||||
| size_t step = channel * sizeof(float); | |||||
| PRelu4x16(in, out, cur_slope, step); | |||||
| } | |||||
| for (; j < channel; j++) { | |||||
| cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); | |||||
| cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; | |||||
| cur_out[j + 2 * channel] = | |||||
| (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); | |||||
| cur_out[j + 3 * channel] = | |||||
| (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); | |||||
| } | } | ||||
| } | |||||
| #endif | #endif | ||||
| // note: First AVX processing, then SSE processing on X86 platform | |||||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||||
| MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f); | |||||
| MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f); | |||||
| float *negetive_slope_value = prelu_param_->slope_; | |||||
| int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM; | |||||
| for (; channel_index < div_channel; channel_index += C4NUM) { | |||||
| MS_FLOAT32X4 slope_value = MS_LDQ_F32(negetive_slope_value + channel_index); | |||||
| LOAD128X8_F32(src, in_plane_ptr + channel_index, channel_num) | |||||
| PRELU_CALCULATE_128X8(dst, src) | |||||
| STORE128X8_F32(out_plane_ptr + channel_index, channel_num, dst) | |||||
| for (; i < end; i++) { | |||||
| const float *cur_in = input + i * channel; | |||||
| float *cur_out = output + i * channel; | |||||
| int j = 0; | |||||
| #if defined(ENABLE_ARM) | |||||
| for (; j < channel - 3; j += 4) { | |||||
| MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j); | |||||
| MS_FLOAT32X4 s = MS_LDQ_F32(slope + j); | |||||
| MS_FLOAT32X4 mul = MS_MULQ_F32(in, s); | |||||
| MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); | |||||
| MS_FLOAT32X4 res = MS_BLENDQ_F32(mul, in, MS_CMPGTQ_F32(in, zero)); | |||||
| MS_STQ_F32(cur_out + j, res); | |||||
| } | } | ||||
| #endif | #endif | ||||
| for (; channel_index < channel_num; channel_index++) { | |||||
| float *in_c = in_plane_ptr + channel_index; | |||||
| float *out_c = out_plane_ptr + channel_index; | |||||
| for (int tile_i = 0; tile_i < TILE_NUM; tile_i++) { | |||||
| float *in_tile = in_c + tile_i * channel_num; | |||||
| float *out_tile = out_c + tile_i * channel_num; | |||||
| const float in_data = in_tile[0]; | |||||
| out_tile[0] = (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0); | |||||
| for (; j < channel; j++) { | |||||
| if (cur_in[j] > 0) { | |||||
| cur_out[j] = cur_in[j]; | |||||
| } else { | |||||
| cur_out[j] = cur_in[j] * slope[j]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| for (; plane_index < plane; plane_index++) { | |||||
| float *in_plane_ptr = input + plane_index * channel_num; | |||||
| float *out_plane_ptr = output + plane_index * channel_num; | |||||
| for (int channel_index = 0; channel_index < channel_num; channel_index++) { | |||||
| const float in_data = in_plane_ptr[channel_index]; | |||||
| out_plane_ptr[channel_index] = | |||||
| (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id) { | |||||
| for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) { | |||||
| int cal_index; | |||||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||||
| cal_index = j * 64; | |||||
| #else | |||||
| cal_index = j * 32; | |||||
| #endif | |||||
| float *input_ptr = input + cal_index; | |||||
| float *output_ptr = input + cal_index; | |||||
| #if defined(ENABLE_AVX) | |||||
| MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0); | |||||
| MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f); | |||||
| MS_FLOAT32X8 slope_value_8 = MS_MOV256_F32(prelu_param_->slope_[0]); | |||||
| LOAD256X8_F32(src, input_ptr, 8) | |||||
| PRELU_CALCULATE_256X8(dst, src) | |||||
| STORE256X8_F32(output_ptr, 8, dst) | |||||
| #elif defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) | |||||
| MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0); | |||||
| MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f); | |||||
| MS_FLOAT32X4 slope_value = MS_MOVQ_F32(prelu_param_->slope_[0]); | |||||
| LOAD128X8_F32(src, input_ptr, 4) | |||||
| #ifdef ENABLE_ARM64 | |||||
| LOAD128X8_F32(src1, input_ptr + 32, 4) | |||||
| #endif | |||||
| PRELU_CALCULATE_128X8(dst, src) | |||||
| #ifdef ENABLE_ARM64 | |||||
| PRELU_CALCULATE_128X8(dst1, src1) | |||||
| #endif | |||||
| STORE128X8_F32(output_ptr, 4, dst) | |||||
| #ifdef ENABLE_ARM64 | |||||
| STORE128X8_F32(output_ptr + 32, 4, dst1) | |||||
| #endif | |||||
| #else | |||||
| const int cal_per_time = 32; | |||||
| for (int i = 0; i < cal_per_time; ++i) { | |||||
| float data = input_ptr[i]; | |||||
| output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0); | |||||
| void PReluShareChannel(const float *input, float *output, float slope, int start, int end) { | |||||
| for (int i = start; i < end; i++) { | |||||
| if (input[i] > 0) { | |||||
| output[i] = input[i]; | |||||
| } else { | |||||
| output[i] = input[i] * slope; | |||||
| } | } | ||||
| #endif | |||||
| } | } | ||||
| } | } | ||||
| @@ -22,39 +22,11 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int task_id); | |||||
| void PRelu(const float *input, float *output, float *slope, int start, int end, int channel); | |||||
| void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id); | |||||
| void PReluShareChannel(const float *input, float *output, float slope, int start, int end); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #define PRELU_CALCULATE_256X8(dst, src) \ | |||||
| MS_FLOAT32X8 dst##1 = \ | |||||
| MS_MUL256_F32(src##1, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##1, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##2 = \ | |||||
| MS_MUL256_F32(src##2, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##2, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##3 = \ | |||||
| MS_MUL256_F32(src##3, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##3, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##4 = \ | |||||
| MS_MUL256_F32(src##4, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##4, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##5 = \ | |||||
| MS_MUL256_F32(src##5, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##5, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##6 = \ | |||||
| MS_MUL256_F32(src##6, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##6, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##7 = \ | |||||
| MS_MUL256_F32(src##7, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##7, zero_value_8, 30))); \ | |||||
| MS_FLOAT32X8 dst##8 = \ | |||||
| MS_MUL256_F32(src##8, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##8, zero_value_8, 30))); | |||||
| #define PRELU_CALCULATE_128X8(dst, src) \ | |||||
| MS_FLOAT32X4 dst##1 = MS_MULQ_F32(src##1, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##1, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##2 = MS_MULQ_F32(src##2, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##2, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##3 = MS_MULQ_F32(src##3, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##3, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##4 = MS_MULQ_F32(src##4, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##4, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##5 = MS_MULQ_F32(src##5, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##5, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##6 = MS_MULQ_F32(src##6, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##6, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##7 = MS_MULQ_F32(src##7, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##7, zero_value))); \ | |||||
| MS_FLOAT32X4 dst##8 = MS_MULQ_F32(src##8, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##8, zero_value))); | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_PRELU_H_ | #endif // MINDSPORE_LITE_NNACL_FP32_PRELU_H_ | ||||
| @@ -27,8 +27,7 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_PReLUFusion; | using mindspore::schema::PrimitiveType_PReLUFusion; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| namespace { | |||||
| int PReluRun(void *cdata, int task_id) { | |||||
| static int PReluRun(void *cdata, int task_id) { | |||||
| auto PRelu = reinterpret_cast<PReluCPUKernel *>(cdata); | auto PRelu = reinterpret_cast<PReluCPUKernel *>(cdata); | ||||
| auto ret = PRelu->DoExcute(task_id); | auto ret = PRelu->DoExcute(task_id); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -37,7 +36,6 @@ int PReluRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace | |||||
| int PReluCPUKernel::Init() { | int PReluCPUKernel::Init() { | ||||
| if (in_tensors_[1]->ElementsNum() == 1) { | if (in_tensors_[1]->ElementsNum() == 1) { | ||||
| @@ -52,26 +50,22 @@ int PReluCPUKernel::Init() { | |||||
| } | } | ||||
| int PReluCPUKernel::DoExcute(int task_id) { | int PReluCPUKernel::DoExcute(int task_id) { | ||||
| int thread_num = prelu_param_->op_parameter_.thread_num_; | |||||
| if (prelu_param_->channelShared) { | if (prelu_param_->channelShared) { | ||||
| PReluShareChannel(input_data_, output_data_, prelu_param_, task_id); | |||||
| int step = UP_DIV(prelu_param_->input_num_, thread_num); | |||||
| int start = task_id * step; | |||||
| int end = MSMIN(start + step, prelu_param_->input_num_); | |||||
| PReluShareChannel(input_data_, output_data_, prelu_param_->slope_[0], start, end); | |||||
| } else { | } else { | ||||
| int res_plane = prelu_param_->input_num_ - task_id * prelu_param_->tile_block_; | |||||
| int plane = MSMIN(prelu_param_->tile_block_, res_plane); | |||||
| if (plane <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| float *in = input_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_; | |||||
| float *out = output_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_; | |||||
| PRelu(in, out, prelu_param_, plane); | |||||
| int step = UP_DIV(prelu_param_->tile_block_, thread_num); | |||||
| int start = task_id * step; | |||||
| int end = MSMIN(start + step, prelu_param_->tile_block_); | |||||
| PRelu(input_data_, output_data_, prelu_param_->slope_, start, end, prelu_param_->channel_num_); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int PReluCPUKernel::ReSize() { | int PReluCPUKernel::ReSize() { | ||||
| if (prelu_param_->channelShared) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto input_tensor = in_tensors_.at(0); | auto input_tensor = in_tensors_.at(0); | ||||
| auto in_shape = input_tensor->shape(); | auto in_shape = input_tensor->shape(); | ||||
| auto n_dim = in_shape.size(); | auto n_dim = in_shape.size(); | ||||
| @@ -81,46 +75,19 @@ int PReluCPUKernel::ReSize() { | |||||
| input_plane *= in_shape.at(i); | input_plane *= in_shape.at(i); | ||||
| } | } | ||||
| prelu_param_->input_num_ = input_plane; | |||||
| prelu_param_->tile_block_ = UP_DIV(UP_DIV(input_plane, TILE_NUM), op_parameter_->thread_num_) * TILE_NUM; | |||||
| prelu_param_->input_num_ = input_plane * channel_num; | |||||
| prelu_param_->tile_block_ = input_plane; | |||||
| prelu_param_->channel_num_ = channel_num; | prelu_param_->channel_num_ = channel_num; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int PReluCPUKernel::ProcessShareChannelInput() { | |||||
| auto input_tensor = in_tensors_.at(0); | |||||
| prelu_param_->input_num_ = input_tensor->ElementsNum(); | |||||
| int tile = 32; | |||||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||||
| tile = 64; | |||||
| #endif | |||||
| prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile); | |||||
| input_data_ = | |||||
| reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * tile * sizeof(float))); | |||||
| if (input_data_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc input_data_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float)); | |||||
| return RET_OK; | |||||
| } | |||||
| int PReluCPUKernel::Run() { | int PReluCPUKernel::Run() { | ||||
| MS_ASSERT(in_tensors_.size() >= 2); | MS_ASSERT(in_tensors_.size() >= 2); | ||||
| auto input_tensor = in_tensors_[0]; | auto input_tensor = in_tensors_[0]; | ||||
| ori_input_ = reinterpret_cast<float *>(input_tensor->data_c()); | |||||
| input_data_ = reinterpret_cast<float *>(input_tensor->data_c()); | |||||
| output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c()); | output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c()); | ||||
| MS_ASSERT(ori_input_); | |||||
| MS_ASSERT(input_data_); | |||||
| MS_ASSERT(output_data_); | MS_ASSERT(output_data_); | ||||
| if (prelu_param_->channelShared) { | |||||
| auto ret = ProcessShareChannelInput(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ProcessShareChannel failed."; | |||||
| return ret; | |||||
| } | |||||
| } else { | |||||
| input_data_ = ori_input_; | |||||
| } | |||||
| // negative slope tensor | // negative slope tensor | ||||
| auto negative_slope_tensor = in_tensors_.at(1); | auto negative_slope_tensor = in_tensors_.at(1); | ||||
| @@ -129,14 +96,9 @@ int PReluCPUKernel::Run() { | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, PReluRun, this, prelu_param_->op_parameter_.thread_num_); | auto ret = ParallelLaunch(this->context_->thread_pool_, PReluRun, this, prelu_param_->op_parameter_.thread_num_); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "PRelu Run error: error_code[" << ret << "]"; | MS_LOG(ERROR) << "PRelu Run error: error_code[" << ret << "]"; | ||||
| context_->allocator->Free(input_data_); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (prelu_param_->channelShared) { | |||||
| memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float)); | |||||
| context_->allocator->Free(input_data_); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -35,11 +35,9 @@ class PReluCPUKernel : public LiteKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int DoExcute(int task_id); | int DoExcute(int task_id); | ||||
| int ProcessShareChannelInput(); | |||||
| private: | private: | ||||
| PReluParameter *prelu_param_; | PReluParameter *prelu_param_; | ||||
| float *ori_input_ = nullptr; | |||||
| float *input_data_ = nullptr; | float *input_data_ = nullptr; | ||||
| float *output_data_ = nullptr; | float *output_data_ = nullptr; | ||||
| }; | }; | ||||