@@ -0,0 +1,129 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/cast_fp16.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Cast;

namespace mindspore::kernel {
namespace {
int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
  if (cdata == nullptr) {
    MS_LOG(ERROR) << "input cdata is nullptr!";
    return RET_ERROR;
  }
  return reinterpret_cast<CastFp16CPUKernel *>(cdata)->DoCast(thread_id);
}
}  // namespace

int CastFp16CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int CastFp16CPUKernel::ReSize() {
  data_num_ = in_tensors_[0]->ElementsNum();
  if (data_num_ == 0) {
    return RET_OK;
  }
  op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, data_num_);
  stride_ = UP_DIV(data_num_, op_parameter_->thread_num_);
  return RET_OK;
}
int CastFp16CPUKernel::DoCast(int thread_id) {
  auto input = in_tensors_.at(0);
  // Compute this thread's slice in signed arithmetic so a trailing thread
  // with no work yields a non-positive count instead of an unsigned
  // wrap-around (stride_ and data_num_ are uint32_t).
  int data_num = MSMIN(static_cast<int>(stride_), static_cast<int>(data_num_) - thread_id * static_cast<int>(stride_));
  if (data_num <= 0) {
    return RET_OK;
  }
  auto offset = thread_id * stride_;
  auto output_data = out_tensors_.at(0)->Data();
  switch (input->data_type()) {
    case kNumberTypeFloat32:
      Float32ToFloat16(reinterpret_cast<float *>(input->Data()) + offset,
                       reinterpret_cast<float16_t *>(output_data) + offset, data_num);
      break;
    case kNumberTypeFloat16:
      Float16ToFloat32(reinterpret_cast<float16_t *>(input->Data()) + offset,
                       reinterpret_cast<float *>(output_data) + offset, data_num);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
      return RET_ERROR;
  }
  return RET_OK;
}
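
// A hedged illustration (the helper name is hypothetical; nothing here is
// called by this patch) of how ReSize() and DoCast() partition the work:
// with data_num_ = 10 and thread_num_ = 4, UP_DIV yields stride_ = 3, so the
// four threads handle 3 / 3 / 3 / 1 elements, and a thread whose slice
// starts past the end gets a non-positive count and returns immediately.
static void CastPartitionExample() {
  const int data_num = 10;
  const int thread_num = 4;
  const int stride = UP_DIV(data_num, thread_num);  // (10 + 4 - 1) / 4 == 3
  for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
    int count = MSMIN(stride, data_num - thread_id * stride);
    if (count <= 0) {
      continue;  // empty trailing slice
    }
    // this thread would cast elements [thread_id * stride, thread_id * stride + count)
  }
}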
int CastFp16CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed! ret: " << prepare_ret;
    return prepare_ret;
  }
  if (data_num_ == 0) {
    return RET_OK;
  }
  return LiteBackendParallelLaunch(CastRun, this, op_parameter_->thread_num_);
}
kernel::LiteKernel *CpuCastFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                             OpParameter *opParameter, const lite::Context *ctx,
                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "Input opParameter is nullptr!";
    return nullptr;
  }
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "Input context is nullptr!";
    return nullptr;
  }
  if (ctx->thread_num_ == 0) {
    MS_LOG(ERROR) << "context thread num is 0!";
    return nullptr;
  }
  auto *kernel = new (std::nothrow) CastFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new CastFp16CPUKernel failed!";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, CpuCastFp16KernelCreator)
}  // namespace mindspore::kernel
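
// A hedged sketch of the callback contract that LiteBackendParallelLaunch
// expects (the real launcher lives behind src/runtime/runtime_api.h; this
// serial stand-in only illustrates the calling convention and is an
// assumption, not the actual implementation): the callback runs once per
// thread_id, and the launch fails if any invocation returns non-RET_OK.
static int SerialLaunchSketch(int (*func)(int, LiteParallelGroupEnv *, void *), void *cdata, int thread_num) {
  for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
    int ret = func(thread_id, nullptr, cdata);
    if (ret != RET_OK) {
      return ret;  // a real launcher would also join the remaining workers
    }
  }
  return RET_OK;
}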
@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_

#include <vector>
#include "src/lite_kernel.h"

namespace mindspore::kernel {
class CastFp16CPUKernel : public LiteKernel {
 public:
  CastFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                    const lite::Primitive *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~CastFp16CPUKernel() = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoCast(int thread_id);

 private:
  // zero-initialize so Run() sees data_num_ == 0 before the first ReSize()
  uint32_t stride_ = 0;
  uint32_t data_num_ = 0;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_
@@ -0,0 +1,149 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/pooling_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/nnacl/fp16/pooling_fp16.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "nnacl/fp16/cast_fp16.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Pooling;

namespace mindspore::kernel {
int PoolingFp16CPUKernel::InitBuffer() {
  // Init() can run more than once (ReSize() re-invokes it), so release any
  // previously allocated staging buffers before allocating new ones;
  // otherwise every resize leaks the old buffers.
  if (fp16_input_ != nullptr) {
    free(fp16_input_);
    fp16_input_ = nullptr;
  }
  if (fp16_output_ != nullptr) {
    free(fp16_output_);
    fp16_output_ = nullptr;
  }
  int in_batch = pooling_param_->input_batch_;
  int in_h = pooling_param_->input_h_;
  int in_w = pooling_param_->input_w_;
  int in_channel = pooling_param_->input_channel_;
  fp16_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * in_h * in_w * in_channel * sizeof(float16_t)));
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
    return RET_ERROR;
  }
  int out_batch = pooling_param_->output_batch_;
  int out_h = pooling_param_->output_h_;
  int out_w = pooling_param_->output_w_;
  int out_channel = pooling_param_->output_channel_;
  fp16_output_ = reinterpret_cast<float16_t *>(malloc(out_batch * out_h * out_w * out_channel * sizeof(float16_t)));
  if (fp16_output_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_output_ failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int PoolingFp16CPUKernel::Init() {
  if (context_->infer_shape_interrupt_ && !context_->running_) {
    set_need_reinit();
    return RET_OK;
  }
  auto ret = PoolingBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "PoolingBase Init failed.";
    return ret;
  }
  ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init Buffer failed.";
    return ret;
  }
  return RET_OK;
}

int PoolingFp16CPUKernel::ReSize() {
  auto ret = Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Pooling resize init failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int PoolingFp16CPUKernel::RunImpl(int task_id) {
  if (pooling_param_->max_pooling_) {
    MaxPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);
  } else {
    AvgPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);
  }
  return RET_OK;
}

int PoolingFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto pooling = reinterpret_cast<PoolingFp16CPUKernel *>(cdata);
  auto error_code = pooling->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}
int PoolingFp16CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed! ret: " << prepare_ret;
    return prepare_ret;
  }
  auto ele_num = in_tensors_.front()->ElementsNum();
  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->Data());
  Float32ToFloat16(input_ptr, fp16_input_, ele_num);

  int error_code = LiteBackendParallelLaunch(PoolingFp16Impl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "pooling launch failed, error_code[" << error_code << "]";
    return RET_ERROR;
  }

  auto out_ele_num = out_tensors_.front()->ElementsNum();
  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
  Float16ToFloat32(fp16_output_, output_ptr, out_ele_num);
  return RET_OK;
}
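
// Note on the staging scheme above (an observation on this code, not new
// behavior): the kernel keeps fp32 tensors at its boundary and round-trips
// through the fp16_input_/fp16_output_ scratch buffers, so each Run() pays
// one fp32->fp16 pass, the pooled fp16 compute, and one fp16->fp32 pass.
// The two conversion passes are O(elements) and sequential; only the pooling
// itself is parallelized across thread_count_.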
kernel::LiteKernel *CpuPoolingFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const Context *ctx,
                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "Input opParameter is nullptr!";
    return nullptr;
  }
  MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
  auto *kernel = new (std::nothrow) PoolingFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new PoolingFp16CPUKernel failed!";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Pooling, CpuPoolingFp16KernelCreator)
}  // namespace mindspore::kernel
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/pooling_base.h"

namespace mindspore::kernel {
class PoolingFp16CPUKernel : public PoolingBaseCPUKernel {
 public:
  PoolingFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                       const lite::Primitive *primitive)
      : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~PoolingFp16CPUKernel() override {
    if (fp16_input_ != nullptr) {
      free(fp16_input_);
    }
    if (fp16_output_ != nullptr) {
      free(fp16_output_);
    }
  }

  int Init() override;
  int InitBuffer();
  int ReSize() override;
  int Run() override;
  int RunImpl(int task_id);

 private:
  float16_t *fp16_input_ = nullptr;
  float16_t *fp16_output_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_
@@ -21,15 +21,8 @@
#include "src/runtime/kernel/arm/base/pooling_base.h"
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "include/context.h"
namespace mindspore::kernel {
using mindspore::lite::Context;
using mindspore::schema::PadMode;
using mindspore::schema::PoolMode;
using mindspore::schema::QuantType;
using mindspore::schema::RoundMode;
class PoolingCPUKernel : public PoolingBaseCPUKernel {
 public:
  PoolingCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
@@ -0,0 +1,28 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/cast_fp16.h"

void Float32ToFloat16(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];
  }
}

void Float16ToFloat32(const float16_t *input, float *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float)input[i];
  }
}
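
// A hedged sketch (not part of this patch; the function name is
// hypothetical): on AArch64, the scalar conversion loops above could be
// vectorized four lanes at a time with the standard arm_neon.h intrinsics
// vld1q_f32 / vcvt_f16_f32 / vst1_f16, falling back to scalar code for the
// tail elements.
#ifdef ENABLE_NEON
void Float32ToFloat16Neon(const float *input, float16_t *output, int number) {
  int i = 0;
  for (; i <= number - 4; i += 4) {
    vst1_f16(output + i, vcvt_f16_f32(vld1q_f32(input + i)));  // 4 lanes at once
  }
  for (; i < number; ++i) {
    output[i] = (float16_t)input[i];  // scalar tail
  }
}
#endif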
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CAST_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CAST_FP16_H_

#include <arm_neon.h>
#include "nnacl/op_base.h"
#include "nnacl/fp32/cast.h"

void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_CAST_FP16_H_
@@ -0,0 +1,276 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/pooling_fp16.h"
#include <float.h>

void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
  int stride_w = pooling_param->stride_w_;
  int stride_h = pooling_param->stride_h_;
  int pad_w = pooling_param->pad_l_;
  int pad_h = pooling_param->pad_u_;
  int win_w = pooling_param->window_w_;
  int win_h = pooling_param->window_h_;
  int channel = pooling_param->input_channel_;
  int c8 = channel / C8NUM;
  int c8_res = channel % C8NUM;
  int c4 = c8_res / C4NUM;
  int in_w = pooling_param->input_w_;
  int in_h = pooling_param->input_h_;
  int output_w = pooling_param->output_w_;
  int output_h = pooling_param->output_h_;
  int output_batch = pooling_param->output_batch_;
  int out_plane = output_w * output_h;
  int out_tile_count = UP_DIV(out_plane, TILE_NUM);
  int thread_num = pooling_param->thread_num_;
  // input channel is equal to output channel
  for (int batch = 0; batch < output_batch; batch++) {
    int in_batch_offset = batch * in_h * in_w * channel;
    int out_batch_offset = batch * output_h * output_w * channel;
    for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
      int cal_start_index = thread_id * TILE_NUM;
      int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
      for (int i = 0; i < real_cal_num; i++) {
        int index = cal_start_index + i;
        int out_w_index = index % output_w;
        int out_h_index = index / output_w;
        int in_w_index = out_w_index * stride_w - pad_w;
        int in_h_index = out_h_index * stride_h - pad_h;
        int out_plane_offset = out_batch_offset + index * channel;
        for (int j = 0; j < c8; j++) {
          int in_channel_offset = in_batch_offset + j * C8NUM;
          int out_channel_offset = out_plane_offset + j * C8NUM;
#ifdef ENABLE_NEON
          float16x8_t tmp_avg = vdupq_n_f16(0);
#else
          float16_t tmp_avg[C8NUM] = {0};
#endif
          int real_count = 0;
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
                tmp_avg = vaddq_f16(tmp_avg, vld1q_f16(input_ptr + in_offset));
#else
                for (int t = 0; t < C8NUM; t++) {
                  tmp_avg[t] += *(input_ptr + in_offset + t);
                }
#endif
                ++real_count;
              }
            }  // win_w loop
          }  // win_h loop
#ifdef ENABLE_NEON
          vst1q_f16(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f16(real_count));
#else
          for (int t = 0; t < C8NUM; ++t) {
            *(output_ptr + out_channel_offset + t) = tmp_avg[t] / (float16_t)real_count;
          }
#endif
        }  // c8 loop
        int c4_offset = c8 * C8NUM;
        for (int l = 0; l < c4; ++l) {
          int in_channel_offset = in_batch_offset + c4_offset + l * C4NUM;
          int out_channel_offset = out_plane_offset + c4_offset + l * C4NUM;
#ifdef ENABLE_NEON
          float16x4_t tmp_avg = vdup_n_f16(0);
#else
          float16_t tmp_avg[C4NUM] = {0};
#endif
          int real_count = 0;
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
                tmp_avg = vadd_f16(tmp_avg, vld1_f16(input_ptr + in_offset));
#else
                for (int j = 0; j < C4NUM; ++j) {
                  tmp_avg[j] += *(input_ptr + in_offset + j);  // fix: accumulate lane j, not lane 0 four times
                }
#endif
                ++real_count;
              }
            }  // win_w loop
          }  // win_h loop
#ifdef ENABLE_NEON
          vst1_f16(output_ptr + out_channel_offset, tmp_avg / vdup_n_f16(real_count));
#else
          for (int t = 0; t < C4NUM; ++t) {
            *(output_ptr + out_channel_offset + t) = tmp_avg[t] / (float16_t)real_count;
          }
#endif
        }  // c4 loop
        int channel_s = c8 * C8NUM + c4 * C4NUM;
        for (int k = channel_s; k < channel; k++) {
          int in_channel_offset = in_batch_offset + k;
          int out_channel_offset = out_plane_offset + k;
          float16_t tmp_avg = 0;
          int real_count = 0;
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
                tmp_avg += *(input_ptr + in_offset);
                ++real_count;
              }
            }  // win_w loop
          }  // win_h loop
          *(output_ptr + out_channel_offset) = tmp_avg / (float16_t)real_count;
        }  // channel_res loop
      }  // real_cal_num loop
    }  // out_plane loop
  }  // out_batch loop
}
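
// A hedged illustration (the helper name is hypothetical; nothing below is
// called by this patch): both pooling kernels split `channel` into full
// 8-lane NEON blocks, at most one 4-lane block, and a scalar tail. For
// example, channel = 21 gives c8 = 2 (lanes 0..15), c4 = 1 (lanes 16..19),
// and one scalar lane (20).
static inline void ChannelSplitExample(int channel, int *c8, int *c4, int *rem) {
  *c8 = channel / C8NUM;                        // full 8-lane blocks
  *c4 = (channel % C8NUM) / C4NUM;              // one optional 4-lane block
  *rem = channel - *c8 * C8NUM - *c4 * C4NUM;   // lanes handled by scalar code
}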
void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
  int stride_w = pooling_param->stride_w_;
  int stride_h = pooling_param->stride_h_;
  int pad_w = pooling_param->pad_l_;
  int pad_h = pooling_param->pad_u_;
  int win_w = pooling_param->window_w_;
  int win_h = pooling_param->window_h_;
  int channel = pooling_param->input_channel_;
  int in_w = pooling_param->input_w_;
  int in_h = pooling_param->input_h_;
  int output_w = pooling_param->output_w_;
  int output_h = pooling_param->output_h_;
  int output_batch = pooling_param->output_batch_;
  int out_plane = output_w * output_h;
  int out_tile_count = UP_DIV(out_plane, TILE_NUM);
  int thread_num = pooling_param->thread_num_;
  int c8 = channel / C8NUM;
  int c8_res = channel % C8NUM;
  int c4 = c8_res / C4NUM;
  // input channel is equal to output channel
  for (int batch = 0; batch < output_batch; batch++) {
    int in_batch_offset = batch * in_h * in_w * channel;
    int out_batch_offset = batch * output_h * output_w * channel;
    for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
      int cal_start_index = thread_id * TILE_NUM;
      int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
      for (int i = 0; i < real_cal_num; i++) {
        int index = cal_start_index + i;
        int out_w_index = index % output_w;
        int out_h_index = index / output_w;
        int in_w_index = out_w_index * stride_w - pad_w;
        int in_h_index = out_h_index * stride_h - pad_h;
        int out_plane_offset = out_batch_offset + index * channel;
        for (int j = 0; j < c8; j++) {
          int in_channel_offset = in_batch_offset + j * C8NUM;
          int out_channel_offset = out_plane_offset + j * C8NUM;
#ifdef ENABLE_NEON
          float16x8_t tmp_max = vdupq_n_f16(-FLT_MAX);
#else
          // fix: `float16_t tmp_max[8]{-FLT_MAX}` set only lane 0 and
          // zero-filled the rest, corrupting results for negative inputs.
          float16_t tmp_max[C8NUM];
          for (int t = 0; t < C8NUM; t++) {
            tmp_max[t] = -FLT_MAX;
          }
#endif
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
                tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset));
#else
                for (int k = 0; k < C8NUM; k++) {
                  tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k));
                }
#endif
              }
            }  // win_w loop
          }  // win_h loop
#ifdef ENABLE_NEON
          vst1q_f16(output_ptr + out_channel_offset, tmp_max);
#else
          for (int l = 0; l < C8NUM; ++l) {
            *(output_ptr + out_channel_offset + l) = tmp_max[l];
          }
#endif
        }  // c8 loop
        int c4_offset = c8 * C8NUM;
        for (int j = 0; j < c4; j++) {
          int in_channel_offset = in_batch_offset + c4_offset + j * C4NUM;
          int out_channel_offset = out_plane_offset + c4_offset + j * C4NUM;
#ifdef ENABLE_NEON
          float16x4_t tmp_max = vdup_n_f16(-FLT_MAX);
#else
          // same fix as the c8 block: initialize every lane, not just lane 0
          float16_t tmp_max[C4NUM];
          for (int t = 0; t < C4NUM; t++) {
            tmp_max[t] = -FLT_MAX;
          }
#endif
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
                tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset));
#else
                for (int k = 0; k < C4NUM; k++) {
                  tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k));
                }
#endif
              }
            }  // win_w loop
          }  // win_h loop
#ifdef ENABLE_NEON
          vst1_f16(output_ptr + out_channel_offset, tmp_max);
#else
          for (int l = 0; l < C4NUM; ++l) {
            *(output_ptr + out_channel_offset + l) = tmp_max[l];
          }
#endif
        }  // c4 loop
        int channel_s = c8 * C8NUM + c4 * C4NUM;
        for (int k = channel_s; k < channel; k++) {
          int in_channel_offset = in_batch_offset + k;
          int out_channel_offset = out_plane_offset + k;
          float16_t tmp_max = -FLT_MAX;
          for (int h = 0; h < win_h; h++) {
            for (int w = 0; w < win_w; w++) {
              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
                  (in_w_index + w) >= in_w) {
                continue;
              } else {
                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
                tmp_max = fmax(tmp_max, *(input_ptr + in_offset));
              }
            }  // win_w loop
          }  // win_h loop
          *(output_ptr + out_channel_offset) = tmp_max;
        }  // channel_res loop
      }  // real_cal_num loop
    }  // out_plane loop
  }  // out_batch loop
}
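
// Boundary-handling note (an observation on the code above, not new
// behavior): window taps that fall outside the input are skipped in both
// kernels. For average pooling the sum is divided by real_count, the number
// of valid taps, so padded positions never dilute the average
// ("count_include_pad = false" in framework terms); for max pooling,
// skipping out-of-range taps is equivalent to padding with -inf.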
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_

#include <arm_neon.h>
#include "nnacl/pooling_parameter.h"

void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id);

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_
@@ -45,17 +45,3 @@ void Float32ToInt32(const float *input, int32_t *output, int number) {
    output[i] = (int32_t)input[i];
  }
}
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];
  }
}
void Float16ToFloat32(const float16_t *input, float *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float)input[i];
  }
}
#endif
@@ -33,9 +33,5 @@ void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
void Int32ToFloat32(const int32_t *input, float *output, int number);
void Float32ToInt32(const float *input, int32_t *output, int number);
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);
#endif
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_H_
@@ -21,35 +21,9 @@
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/quantization/quantize.h"
typedef struct PoolingParameter {
  OpParameter op_parameter_;
  QuantArg **quant_args_;
  bool global_;
  bool max_pooling_;
  bool avg_pooling_;
  bool round_ceil_;
  bool round_floor_;
  int window_w_;
  int window_h_;
  int input_w_;
  int input_h_;
  int input_batch_;
  int input_channel_;
  int output_w_;
  int output_h_;
  int output_batch_;
  int output_channel_;
  int pad_u_;
  int pad_d_;
  int pad_l_;
  int pad_r_;
  int stride_w_;
  int stride_h_;
  int thread_num_;
} PoolingParameter;
void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
@@ -0,0 +1,49 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_

#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"

typedef struct PoolingParameter {
  OpParameter op_parameter_;
  QuantArg **quant_args_;
  bool global_;
  bool max_pooling_;
  bool avg_pooling_;
  bool round_ceil_;
  bool round_floor_;
  int window_w_;
  int window_h_;
  int input_w_;
  int input_h_;
  int input_batch_;
  int input_channel_;
  int output_w_;
  int output_h_;
  int output_batch_;
  int output_channel_;
  int pad_u_;
  int pad_d_;
  int pad_l_;
  int pad_r_;
  int stride_w_;
  int stride_h_;
  int thread_num_;
} PoolingParameter;

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_
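
// A hedged usage sketch (illustrative values only; the runtime normally
// fills this struct when parsing the model, and the helper name is
// hypothetical): configuring a 2x2, stride-2, unpadded max pool over a
// 1x16x16x8 NHWC input, which produces a 1x8x8x8 output.
static inline PoolingParameter ExampleMaxPool2x2() {
  PoolingParameter param = {};  // zero everything, including op_parameter_
  param.max_pooling_ = true;
  param.avg_pooling_ = false;
  param.window_w_ = 2;
  param.window_h_ = 2;
  param.stride_w_ = 2;
  param.stride_h_ = 2;  // pads stay 0
  param.input_batch_ = 1;
  param.input_h_ = 16;
  param.input_w_ = 16;
  param.input_channel_ = 8;
  param.output_batch_ = 1;
  param.output_h_ = 8;
  param.output_w_ = 8;
  param.output_channel_ = 8;
  param.thread_num_ = 1;
  return param;
}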