diff --git a/mindspore/lite/nnacl/fp32/prelu_fp32.c b/mindspore/lite/nnacl/fp32/prelu_fp32.c index 016cc272f6..afe513bafd 100644 --- a/mindspore/lite/nnacl/fp32/prelu_fp32.c +++ b/mindspore/lite/nnacl/fp32/prelu_fp32.c @@ -18,89 +18,70 @@ #include #endif -void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int task_id) { - float *negetive_slope_value = prelu_param_->slope_; - int c4 = prelu_param_->channel_num_ / C4NUM; +void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) { +#ifdef ENABLE_ARM + float32x4_t zero_value = vdupq_n_f32(0); +#endif + int plane_tile = plane / TILE_NUM * TILE_NUM; int channel_num = prelu_param_->channel_num_; - for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) { - float *input_ptr = input + j * TILE_NUM * channel_num; - float *output_ptr = input_ptr; -#ifdef ENABLE_ARM64 - for (int i = 0; i < c4; i++) { - int c_offset = i * C4NUM; - float32x4_t slope_value = vld1q_f32(negetive_slope_value + c_offset); - float32x4_t v1 = vld1q_f32(input_ptr + c_offset); - float32x4_t v2 = vld1q_f32(input_ptr + c_offset + channel_num); - float32x4_t v3 = vld1q_f32(input_ptr + c_offset + 2 * channel_num); - float32x4_t v4 = vld1q_f32(input_ptr + c_offset + 3 * channel_num); - float32x4_t v5 = vld1q_f32(input_ptr + c_offset + 4 * channel_num); - float32x4_t v6 = vld1q_f32(input_ptr + c_offset + 5 * channel_num); - float32x4_t v7 = vld1q_f32(input_ptr + c_offset + 6 * channel_num); - float32x4_t v8 = vld1q_f32(input_ptr + c_offset + 7 * channel_num); + int plane_index = 0; + for (; plane_index < plane_tile; plane_index += TILE_NUM) { + float *in_plane_ptr = input + plane_index * channel_num; + float *out_plane_ptr = output + plane_index * channel_num; + int channel_index = 0; +#ifdef ENABLE_ARM + float *negetive_slope_value = prelu_param_->slope_; + int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM; + for (; channel_index < div_channel; channel_index += C4NUM) { + float32x4_t slope_value = vld1q_f32(negetive_slope_value + channel_index); + float32x4_t v1 = vld1q_f32(in_plane_ptr + channel_index + 0 * channel_num); + float32x4_t v2 = vld1q_f32(in_plane_ptr + channel_index + 1 * channel_num); + float32x4_t v3 = vld1q_f32(in_plane_ptr + channel_index + 2 * channel_num); + float32x4_t v4 = vld1q_f32(in_plane_ptr + channel_index + 3 * channel_num); + float32x4_t v5 = vld1q_f32(in_plane_ptr + channel_index + 4 * channel_num); + float32x4_t v6 = vld1q_f32(in_plane_ptr + channel_index + 5 * channel_num); + float32x4_t v7 = vld1q_f32(in_plane_ptr + channel_index + 6 * channel_num); + float32x4_t v8 = vld1q_f32(in_plane_ptr + channel_index + 7 * channel_num); - float32x4_t t1 = vmulq_f32(v1, slope_value); - float32x4_t t2 = vmulq_f32(v2, slope_value); - float32x4_t t3 = vmulq_f32(v3, slope_value); - float32x4_t t4 = vmulq_f32(v4, slope_value); - float32x4_t t5 = vmulq_f32(v5, slope_value); - float32x4_t t6 = vmulq_f32(v6, slope_value); - float32x4_t t7 = vmulq_f32(v7, slope_value); - float32x4_t t8 = vmulq_f32(v8, slope_value); + float32x4_t r1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value)); + float32x4_t r2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value)); + float32x4_t r3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value)); + float32x4_t r4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value)); + float32x4_t r5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value)); + float32x4_t r6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value)); + float32x4_t r7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value)); + float32x4_t r8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value)); - uint32x4_t flag1 = vclezq_f32(v1); - uint32x4_t flag2 = vclezq_f32(v2); - uint32x4_t flag3 = vclezq_f32(v3); - uint32x4_t flag4 = vclezq_f32(v4); - uint32x4_t flag5 = vclezq_f32(v5); - uint32x4_t flag6 = vclezq_f32(v6); - uint32x4_t flag7 = vclezq_f32(v7); - uint32x4_t flag8 = vclezq_f32(v8); - - float32x4_t r1 = vbslq_f32(flag1, t1, v1); - float32x4_t r2 = vbslq_f32(flag2, t2, v2); - float32x4_t r3 = vbslq_f32(flag3, t3, v3); - float32x4_t r4 = vbslq_f32(flag4, t4, v4); - float32x4_t r5 = vbslq_f32(flag5, t5, v5); - float32x4_t r6 = vbslq_f32(flag6, t6, v6); - float32x4_t r7 = vbslq_f32(flag7, t7, v7); - float32x4_t r8 = vbslq_f32(flag8, t8, v8); - - vst1q_f32(output_ptr + c_offset, r1); - vst1q_f32(output_ptr + c_offset + channel_num, r2); - vst1q_f32(output_ptr + c_offset + 2 * channel_num, r3); - vst1q_f32(output_ptr + c_offset + 3 * channel_num, r4); - vst1q_f32(output_ptr + c_offset + 4 * channel_num, r5); - vst1q_f32(output_ptr + c_offset + 5 * channel_num, r6); - vst1q_f32(output_ptr + c_offset + 6 * channel_num, r7); - vst1q_f32(output_ptr + c_offset + 7 * channel_num, r8); - } // c4 -1 loop -#else - for (int i = 0; i < TILE_NUM; ++i) { - int tile_offset = i * channel_num; - for (int k = 0; k < c4; ++k) { - int c4_offset = tile_offset + k * C4NUM; - int slope_offset = k * C4NUM; - for (int l = 0; l < C4NUM; ++l) { - const float in_data = input_ptr[c4_offset + l]; - output_ptr[c4_offset + l] = - (in_data < 0 ? in_data : 0) * negetive_slope_value[slope_offset + l] + (in_data > 0 ? in_data : 0); - } - } - } // c4 - 1 loop + vst1q_f32(out_plane_ptr + channel_index + 0 * channel_num, r1); + vst1q_f32(out_plane_ptr + channel_index + 1 * channel_num, r2); + vst1q_f32(out_plane_ptr + channel_index + 2 * channel_num, r3); + vst1q_f32(out_plane_ptr + channel_index + 3 * channel_num, r4); + vst1q_f32(out_plane_ptr + channel_index + 4 * channel_num, r5); + vst1q_f32(out_plane_ptr + channel_index + 5 * channel_num, r6); + vst1q_f32(out_plane_ptr + channel_index + 6 * channel_num, r7); + vst1q_f32(out_plane_ptr + channel_index + 7 * channel_num, r8); + } #endif - int c_s = c4 * C4NUM; - for (int m = 0; m < TILE_NUM; ++m) { - int offset = m * channel_num; - for (int k = c_s; k < channel_num; ++k) { - int c4_offset = offset + k; - const float in_data = input_ptr[c4_offset]; - if (in_data >= 0) { - output_ptr[c4_offset] = in_data; - } else { - output_ptr[c4_offset] = in_data * negetive_slope_value[k]; - } + for (; channel_index < channel_num; channel_index++) { + float *in_c = in_plane_ptr + channel_index; + float *out_c = out_plane_ptr + channel_index; + for (int tile_i = 0; tile_i < TILE_NUM; tile_i++) { + float *in_tile = in_c + tile_i * channel_num; + float *out_tile = out_c + tile_i * channel_num; + const float in_data = in_tile[0]; + out_tile[0] = (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0); } - } // res loop + } + } + + for (; plane_index < plane; plane_index++) { + float *in_plane_ptr = input + plane_index * channel_num; + float *out_plane_ptr = output + plane_index * channel_num; + for (int channel_index = 0; channel_index < channel_num; channel_index++) { + const float in_data = in_plane_ptr[channel_index]; + out_plane_ptr[channel_index] = + (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0); + } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc index 9733979f40..d7422f77bd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc @@ -39,19 +39,39 @@ int PReluRun(void *cdata, int task_id) { } } // namespace -int PReluCPUKernel::Init() { return RET_OK; } +int PReluCPUKernel::Init() { + if (in_tensors_[1]->ElementsNum() == 1) { + prelu_param_->channelShared = true; + } else { + prelu_param_->channelShared = false; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} int PReluCPUKernel::DoExcute(int task_id) { if (prelu_param_->channelShared) { PReluShareChannel(input_data_, output_data_, prelu_param_, task_id); } else { - PRelu(input_data_, output_data_, prelu_param_, task_id); + int res_plane = prelu_param_->input_num_ - task_id * prelu_param_->tile_block_; + int plane = MSMIN(prelu_param_->tile_block_, res_plane); + if (plane <= 0) { + return RET_OK; + } + float *in = input_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_; + float *out = output_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_; + PRelu(in, out, prelu_param_, plane); } return RET_OK; } -int PReluCPUKernel::ProcessInput() { - // input tensor +int PReluCPUKernel::ReSize() { + if (prelu_param_->channelShared) { + return RET_OK; + } + auto input_tensor = in_tensors_.at(0); auto in_shape = input_tensor->shape(); auto n_dim = in_shape.size(); @@ -60,57 +80,36 @@ int PReluCPUKernel::ProcessInput() { for (size_t i = 0; i < n_dim - 1; ++i) { input_plane *= in_shape.at(i); } - int tile_block = UP_DIV(input_plane, TILE_NUM); - prelu_param_->input_num_ = input_tensor->ElementsNum(); - prelu_param_->tile_block_ = tile_block; + + prelu_param_->input_num_ = input_plane; + prelu_param_->tile_block_ = UP_DIV(UP_DIV(input_plane, TILE_NUM), op_parameter_->thread_num_) * TILE_NUM; prelu_param_->channel_num_ = channel_num; - input_data_ = - reinterpret_cast(context_->allocator->Malloc(tile_block * TILE_NUM * channel_num * sizeof(float))); - if (input_data_ == nullptr) { - MS_LOG(ERROR) << "malloc input_data_ failed."; - return RET_ERROR; - } - memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float)); return RET_OK; } int PReluCPUKernel::ProcessShareChannelInput() { - // input tensor auto input_tensor = in_tensors_.at(0); prelu_param_->input_num_ = input_tensor->ElementsNum(); + int tile = 32; #ifdef ENABLE_ARM64 - prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 64); - input_data_ = reinterpret_cast(context_->allocator->Malloc(prelu_param_->tile_block_ * 64 * sizeof(float))); - if (input_data_ == nullptr) { - MS_LOG(ERROR) << "malloc input_data_ failed."; - return RET_ERROR; - } - memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float)); -#elif ENABLE_ARM32 - prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32); - input_data_ = reinterpret_cast(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float))); - if (input_data_ == nullptr) { - MS_LOG(ERROR) << "malloc input_data_ failed."; - return RET_ERROR; - } - memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float)); -#else - prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, 32); - input_data_ = reinterpret_cast(context_->allocator->Malloc(prelu_param_->tile_block_ * 32 * sizeof(float))); + tile = 64; +#endif + prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile); + input_data_ = + reinterpret_cast(context_->allocator->Malloc(prelu_param_->tile_block_ * tile * sizeof(float))); if (input_data_ == nullptr) { MS_LOG(ERROR) << "malloc input_data_ failed."; return RET_ERROR; } memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float)); -#endif return RET_OK; } int PReluCPUKernel::Run() { MS_ASSERT(in_tensors_.size() >= 2); auto input_tensor = in_tensors_[0]; - ori_input_ = reinterpret_cast(input_tensor->MutableData()); - output_data_ = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); + ori_input_ = reinterpret_cast(input_tensor->data_c()); + output_data_ = reinterpret_cast(out_tensors_.at(kOutputIndex)->data_c()); MS_ASSERT(ori_input_); MS_ASSERT(output_data_); if (prelu_param_->channelShared) { @@ -120,16 +119,12 @@ int PReluCPUKernel::Run() { return ret; } } else { - auto ret = ProcessInput(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Process failed."; - return ret; - } + input_data_ = ori_input_; } // negative slope tensor auto negative_slope_tensor = in_tensors_.at(1); - prelu_param_->slope_ = reinterpret_cast(negative_slope_tensor->MutableData()); + prelu_param_->slope_ = reinterpret_cast(negative_slope_tensor->data_c()); auto ret = ParallelLaunch(this->context_->thread_pool_, PReluRun, this, prelu_param_->op_parameter_.thread_num_); if (ret != RET_OK) { @@ -138,8 +133,10 @@ int PReluCPUKernel::Run() { return RET_ERROR; } - memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float)); - context_->allocator->Free(input_data_); + if (prelu_param_->channelShared) { + memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float)); + context_->allocator->Free(input_data_); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.h index 54dfa1c8c6..932908153d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.h @@ -33,11 +33,10 @@ class PReluCPUKernel : public LiteKernel { ~PReluCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; int DoExcute(int task_id); int ProcessShareChannelInput(); - int ProcessInput(); private: PReluParameter *prelu_param_;