@@ -0,0 +1,97 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
int Convolution1x1FP16CPUKernel::Init() {
  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return ret;
  }
  return RET_OK;
}

int Convolution1x1FP16CPUKernel::ReSize() {
  if (fp16_out_ != nullptr) {
    free(fp16_out_);
  }
  if (fp16_input_ != nullptr) {
    free(fp16_input_);
  }
  if (nhwc4_input_ != nullptr) {
    free(nhwc4_input_);
  }
  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return ret;
  }
  return RET_OK;
}
int Convolution1x1FP16CPUKernel::RunImpl(int task_id) {
  // The 1x1 fp16 compute path is not wired up yet; the Conv1x1Fp16 call stays
  // disabled, so this kernel currently produces no output.
  // Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
  //             reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_,
  //             tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
  return RET_OK;
}
int Convolution1x1Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv = reinterpret_cast<Convolution1x1FP16CPUKernel *>(cdata);
  auto error_code = conv->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Convolution1x1 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int Convolution1x1FP16CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed.";
    return RET_ERROR;
  }
  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
  int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
    return RET_ERROR;
  }
  ConvolutionBaseFP16CPUKernel::IfCastOutput();
  return RET_OK;
}
}  // namespace mindspore::kernel
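// [Editor's note, not part of the patch] All fp16 conv kernels in this change
// share the same threading pattern: a free-function trampoline is handed to
// LiteBackendParallelLaunch together with `this` as the opaque cdata pointer.
// Assuming the launcher invokes the functor once per task_id in
// [0, thread_count_), the pattern reduces to:
//
//   int Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
//     return reinterpret_cast<KernelT *>(cdata)->RunImpl(task_id);
//   }
//   LiteBackendParallelLaunch(Impl, this, thread_count_);
//
// where KernelT is the concrete kernel class and each RunImpl(task_id)
// computes an independent slice of the output.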
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"

namespace mindspore::kernel {
class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
 public:
  Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                              const lite::Primitive *primitive)
      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~Convolution1x1FP16CPUKernel() override {
    if (fp16_input_ != nullptr) {
      free(fp16_input_);
    }
    if (fp16_weight_ != nullptr) {
      free(fp16_weight_);
    }
    if (fp16_out_ != nullptr) {
      free(fp16_out_);
    }
  }

  int Init() override;
  int ReSize() override;
  int Run() override;
  int RunImpl(int task_id);

 private:
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
@@ -52,8 +52,6 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter
 int Convolution3x3FP16CPUKernel::InitWeightBias() {
   auto input_channel = conv_param_->input_channel_;
   int output_channel = conv_param_->output_channel_;
-  int kernel_h = conv_param_->kernel_h_;
-  int kernel_w = conv_param_->kernel_w_;
   int iC4 = UP_DIV(input_channel, C4NUM);
   int oC8 = UP_DIV(output_channel, C8NUM);
   // init weight
@@ -64,18 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
   memset(transformed_filter_addr_, 0, transformed_size);
-  float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
-  size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
-  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-  if (fp16_weight_ == nullptr) {
-    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-    return RET_ERROR;
-  }
-  memset(fp16_weight_, 0, fp16_weight_size);
-  for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-    fp16_weight_[i] = (float16_t)origin_weight[i];
-  }
-  ProcessFilterFp16(fp16_weight_, transformed_filter_addr_, conv_param_);
+  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);

   // init bias
   size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
@@ -183,10 +171,6 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() {
 }

 int Convolution3x3FP16CPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
@@ -244,8 +228,8 @@ int Convolution3x3FP16CPUKernel::ReSize() {
 }

 int Convolution3x3FP16CPUKernel::RunImpl(int task_id) {
   Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
-              reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_,
-              tmp_out_, task_id, conv_param_);
+              reinterpret_cast<float16_t *>(bias_data_), execute_output_, tile_buffer_, block_unit_buffer_,
+              tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
   return RET_OK;
 }
@@ -265,16 +249,13 @@ int Convolution3x3FP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }

-  auto input_tensor = in_tensors_.at(kInputIndex);
-  auto input_ele_num = input_tensor->ElementsNum();
-  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
-  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
+  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
   int in_w = conv_param_->input_w_;
   int in_channel = conv_param_->input_channel_;
-  convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
+  convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

   int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_);
   if (error_code != RET_OK) {
@@ -294,7 +275,7 @@ int Convolution3x3FP16CPUKernel::Run() {
       batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
     int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
     const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
-    float16_t *batch_out = fp16_out_ + ro_batch_size;
+    float16_t *batch_out = execute_output_ + ro_batch_size;
     for (int h = 0; h < conv_param_->output_h_; h++) {
       for (int w = 0; w < conv_param_->output_w_; w++) {
         for (int c = 0; c < conv_param_->output_channel_; c++) {
@@ -315,11 +296,7 @@ int Convolution3x3FP16CPUKernel::Run() {
       }
     }
-  // cast fp16 out to fp32 data
-  auto out_tensor = out_tensors_.at(kOutputIndex);
-  auto out_ele_num = out_tensor->ElementsNum();
-  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
-  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
+  ConvolutionBaseFP16CPUKernel::IfCastOutput();
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -20,16 +20,16 @@
 #include <arm_neon.h>
 #include <vector>
 #include "src/lite_kernel.h"
-#include "src/runtime/kernel/arm/base/convolution_base.h"
+#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"

 namespace mindspore::kernel {
-class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
+class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                               const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                               const lite::Primitive *primitive)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~Convolution3x3FP16CPUKernel() override {
     if (fp16_input_ != nullptr) {
       free(fp16_input_);
@@ -66,9 +66,6 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
   void ConfigInputOutput();

  private:
-  float16_t *fp16_input_;
-  float16_t *fp16_weight_;
-  float16_t *fp16_out_;
   float16_t *transformed_filter_addr_;
   float16_t *tile_buffer_;
   float16_t *block_unit_buffer_;
@@ -0,0 +1,86 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

namespace mindspore::kernel {
int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
  // ===================input====================//
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_data_type = input_tensor->data_type();
  MS_ASSERT(input_data_type == kNumberTypeFloat32 || input_data_type == kNumberTypeFloat16);
  if (input_data_type == kNumberTypeFloat32) {
    // fp32 input: cast it into the pre-allocated fp16_input_ buffer
    auto input_ele_num = input_tensor->ElementsNum();
    auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
    Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
    execute_input_ = fp16_input_;
  } else {
    // fp16 input: use the tensor data directly
    auto ori_input_data = reinterpret_cast<float16_t *>(input_tensor->Data());
    execute_input_ = ori_input_data;
  }

  // ==================output====================//
  auto out_tensor = out_tensors_.at(kOutputIndex);
  auto out_data_type = out_tensor->data_type();
  MS_ASSERT(out_data_type == kNumberTypeFloat32 || out_data_type == kNumberTypeFloat16);
  out_data_type_ = out_data_type;
  if (out_data_type == kNumberTypeFloat32) {
    // fp32 output: compute into fp16_out_ and cast back later in IfCastOutput()
    execute_output_ = fp16_out_;
  } else {
    auto out_ptr = reinterpret_cast<float16_t *>(out_tensor->Data());
    execute_output_ = out_ptr;
  }
  return RET_OK;
}
int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  auto weight_data_type = weight_tensor->data_type();
  MS_ASSERT(weight_data_type == kNumberTypeFloat32 || weight_data_type == kNumberTypeFloat16);
  if (weight_data_type == kNumberTypeFloat32) {
    float *origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
    // weight element count is in_channel * out_channel * kernel_h * kernel_w
    size_t fp16_weight_size = conv_param_->input_channel_ * conv_param_->output_channel_ * conv_param_->kernel_h_ *
                              conv_param_->kernel_w_ * sizeof(float16_t);
    fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
    if (fp16_weight_ == nullptr) {
      MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
      return RET_ERROR;
    }
    for (size_t i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
      fp16_weight_[i] = (float16_t)origin_weight[i];
    }
    execute_weight_ = fp16_weight_;
  } else {
    auto *origin_weight = reinterpret_cast<float16_t *>(weight_tensor->Data());
    execute_weight_ = origin_weight;
  }
  return RET_OK;
}
void ConvolutionBaseFP16CPUKernel::IfCastOutput() {
  if (out_data_type_ == kNumberTypeFloat32) {
    auto out_tensor = out_tensors_.at(kOutputIndex);
    auto out_ele_num = out_tensor->ElementsNum();
    auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
    Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
  }
}
}  // namespace mindspore::kernel
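// [Editor's sketch, not part of the patch] The intended call order for a
// derived fp16 kernel's Run(), as used by the 3x3, sliding-window, and
// winograd kernels in this change; DoCompute is a hypothetical placeholder
// for the kernel-specific math:
//
//   int SomeConvFP16CPUKernel::Run() {
//     GetExecuteTensor();   // resolves execute_input_/execute_output_, casting
//                           // fp32 tensors into fp16_input_/fp16_out_ as needed
//     DoCompute(execute_input_, execute_weight_, execute_output_);
//     IfCastOutput();       // casts fp16_out_ back to fp32 only when the
//                           // output tensor is fp32
//     return lite::RET_OK;
//   }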
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "include/errorcode.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"

namespace mindspore::kernel {
class ConvolutionBaseFP16CPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionBaseFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                               const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                               const lite::Primitive *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ConvolutionBaseFP16CPUKernel() override = default;

  int Init() override { return lite::RET_OK; }
  int ReSize() override { return lite::RET_OK; }
  int Run() override { return lite::RET_OK; }
  int RunImpl(int task_id) { return lite::RET_OK; }
  virtual int GetExecuteTensor();
  virtual int GetExecuteFilter();
  virtual void IfCastOutput();

 protected:
  float16_t *fp16_input_ = nullptr;
  float16_t *fp16_weight_ = nullptr;
  float16_t *fp16_out_ = nullptr;
  float16_t *execute_input_ = nullptr;
  float16_t *execute_weight_ = nullptr;
  float16_t *execute_output_ = nullptr;
  TypeId out_data_type_;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
@@ -102,10 +102,6 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
 }

 int ConvolutionDepthwiseFp16CPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   // conv base init
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
@@ -46,24 +46,14 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
   int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;

   // init weight
-  float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
-  size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
-  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-  if (fp16_weight_ == nullptr) {
-    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-    return RET_ERROR;
-  }
-  for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-    fp16_weight_[i] = (float16_t)origin_weight[i];
-  }
+  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
   packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
   if (packed_weight_ == nullptr) {
     MS_LOG(ERROR) << "malloc packed_weight_ failed.";
     return RET_ERROR;
   }
   memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
-  PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);
+  PackWeightFp16(execute_weight_, conv_param_, packed_weight_);

   // init bias
   bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
@@ -157,10 +147,6 @@ void ConvolutionFP16CPUKernel::ConfigInputOutput() {
 }

 int ConvolutionFP16CPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
@@ -212,7 +198,7 @@ int ConvolutionFP16CPUKernel::ReSize() {
 }

 int ConvolutionFP16CPUKernel::RunImpl(int task_id) {
   ConvFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_input_, packed_weight_,
-           reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, fp16_out_, task_id, conv_param_);
+           reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, execute_output_, task_id, conv_param_);
   return RET_OK;
 }
@@ -232,16 +218,13 @@ int ConvolutionFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }

-  auto input_tensor = in_tensors_.at(kInputIndex);
-  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
-  auto input_ele_num = input_tensor->ElementsNum();
-  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
+  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
   int in_w = conv_param_->input_w_;
   int in_channel = conv_param_->input_channel_;
-  convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
+  convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

   int error_code = LiteBackendParallelLaunch(ConvolutionFp16Impl, this, thread_count_);
   if (error_code != RET_OK) {
@@ -249,11 +232,7 @@ int ConvolutionFP16CPUKernel::Run() {
     return RET_ERROR;
   }
-  // cast fp16 out to fp32 data
-  auto out_tensor = out_tensors_.at(kOutputIndex);
-  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
-  auto out_ele_num = out_tensor->ElementsNum();
-  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
+  ConvolutionBaseFP16CPUKernel::IfCastOutput();
   return RET_OK;
 }
@@ -20,15 +20,15 @@
 #include <arm_neon.h>
 #include <vector>
 #include "src/lite_kernel.h"
-#include "src/runtime/kernel/arm/base/convolution_base.h"
+#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"

 namespace mindspore::kernel {
-class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
+class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                            const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                            const lite::Primitive *primitive)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionFP16CPUKernel() override {
     if (fp16_input_ != nullptr) {
       free(fp16_input_);
@@ -59,9 +59,6 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
   void ConfigInputOutput();

  private:
-  float16_t *fp16_input_;
-  float16_t *fp16_weight_;
-  float16_t *fp16_out_;
   float16_t *packed_input_;
   float16_t *packed_weight_;
   float16_t *tmp_output_block_;
@@ -39,23 +39,13 @@ int ConvolutionSWFP16CPUKernel::ProcessFilter() {
   int out_channel = conv_param_->output_channel_;
   int ic4 = UP_DIV(in_channel, C4NUM);

-  auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
-  size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
-  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
-  if (fp16_weight_ == nullptr) {
-    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
-    return RET_ERROR;
-  }
-  // cast origin fp32 weight data to fp16 data
-  for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
-    fp16_weight_[i] = (float16_t)origin_weight[i];
-  }
+  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();

   for (int oc = 0; oc < out_channel; ++oc) {
     int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
     int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
     for (int i = 0; i < kernel_h * kernel_w; ++i) {
-      const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel;
+      const float16_t *src = execute_weight_ + src_oc_offset + i * in_channel;
       float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
       memcpy(dst, src, in_channel * sizeof(float16_t));
     }
@@ -162,10 +152,6 @@ void ConvolutionSWFP16CPUKernel::ConfigInputOutput() {
 }

 int ConvolutionSWFP16CPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
@@ -222,7 +208,7 @@ int ConvolutionSWFP16CPUKernel::ReSize() {
 }

 int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) {
   ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_),
-             tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_);
+             tmp_output_block_, execute_output_, task_id, conv_param_, slidingWindow_param_);
   return RET_OK;
 }
@@ -242,16 +228,13 @@ int ConvolutionSWFP16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }

-  auto input_tensor = in_tensors_.at(kInputIndex);
-  auto input_ele_num = input_tensor->ElementsNum();
-  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
-  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
+  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
   int in_w = conv_param_->input_w_;
   int in_channel = conv_param_->input_channel_;
-  convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
+  convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

   int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_);
   if (error_code != RET_OK) {
@@ -259,18 +242,14 @@ int ConvolutionSWFP16CPUKernel::Run() {
     return RET_ERROR;
   }
-  // cast fp16 out to fp32 data
-  auto out_tensor = out_tensors_.at(kOutputIndex);
-  auto out_ele_num = out_tensor->ElementsNum();
-  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
   // output nhwc4
   int oc4_res = conv_param_->output_channel_ % C4NUM;
   if (oc4_res != 0) {
-    PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(fp16_out_),
+    PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(execute_output_),
                         conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_,
                         conv_param_->output_channel_);
   }
-  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
+  ConvolutionBaseFP16CPUKernel::IfCastOutput();
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -19,15 +19,15 @@
 #include <arm_neon.h>
 #include <vector>
 #include "src/lite_kernel.h"
-#include "src/runtime/kernel/arm/base/convolution_base.h"
+#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"

 namespace mindspore::kernel {
-class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
+class ConvolutionSWFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                              const lite::Primitive *primitive)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionSWFP16CPUKernel() override {
     if (fp16_input_ != nullptr) {
       free(fp16_input_);
@@ -57,9 +57,6 @@ class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
   int ProcessFilter();

  private:
-  float16_t *fp16_input_;
-  float16_t *fp16_weight_;
-  float16_t *fp16_out_;
   float16_t *packed_weight_;
   float16_t *tmp_output_block_;
   SlidingWindowParam *slidingWindow_param_;
@@ -0,0 +1,409 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/common_func.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
                                 ConvParameter *conv_param, int oc_block) {
  // original weight format : ohwi
  auto channel_in = conv_param->input_channel_;
  auto channel_out = conv_param->output_channel_;
  int input_unit_square = input_unit * input_unit;

  // generate matrix_G && matrix_GT
  auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit);
  auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit);
  ChooseMatrixG(matrix_g, matrix_gt);
  auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
  auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
  auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
  auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
  Float32ToFloat16(matrix_g_data, matrix_g_data_fp16, input_unit * kernel_unit);
  Float32ToFloat16(matrix_gt_data, matrix_gt_data_fp16, input_unit * kernel_unit);

  // trans_filter = G*g*GT (g represents weight_data)
  // separate into two steps ===> tmp = G*g ===> out = tmp * GT
  auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit * kernel_unit * sizeof(float16_t)));
  auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
  auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit * input_unit * sizeof(float16_t)));
  bool row = true;
  auto trans_weight_data = reinterpret_cast<float16_t *>(trans_weight->GetData());
  std::vector<int> strides = trans_weight->GetStride();

  int kernel_plane_stride = channel_in;
  if (oc_block == 0) {
    MS_LOG(ERROR) << "Divide by zero";
    return;
  }
  for (int i = 0; i < channel_out; i++) {
    int out_c_block = i / oc_block;
    int out_c_res = i % oc_block;
    int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
    int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res;
    for (int j = 0; j < channel_in; j++) {
      int ic4_block = j / C4NUM;
      int ic4_res = j % C4NUM;
      int input_iz_offset = input_oz_offset + j;
      int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
      for (int k = 0; k < kernel_unit * kernel_unit; k++) {
        int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
        tmp_weight_data[k] = *(weight_data + input_xy_offset);
      }
      // now we only support row-major matrix-multiply
      // tmp = G * g
      MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row);
      // out = tmp * GT
      MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, row);
      for (int z = 0; z < input_unit_square; z++) {
        int output_xy_offset = output_iz_offset + z * strides[1];
        *(trans_weight_data + output_xy_offset) = trans_out_data[z];
      }
    }
  }
  free(tmp_weight_data);
  free(tmp_data);
  free(trans_out_data);
  free(matrix_g_data_fp16);
  free(matrix_gt_data_fp16);
  delete matrix_g;
  delete matrix_gt;
}
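// [Editor's note] Dimension check for the two MatrixMultiplyFp16 calls above,
// derived from the (m, k, n) arguments:
//   G  : input_unit  x kernel_unit
//   g  : kernel_unit x kernel_unit   (one kernel plane of one (oc, ic) pair)
//   GT : kernel_unit x input_unit
//   tmp = G * g    -> input_unit x kernel_unit
//   out = tmp * GT -> input_unit x input_unit   (input_unit_square values)
// so each (oc, ic) pair scatters input_unit_square transformed values into
// trans_weight_data, stepping by strides[1] per transform-domain position.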
int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
  int output_channel = conv_param_->output_channel_;
  int oc_block = C8NUM;
  int oc_block_num = UP_DIV(output_channel, C8NUM);

  // init weight
  auto ret = MallocFilterMatrix(oc_block, oc_block_num);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc filter matrix failed.";
    return RET_ERROR;
  }
  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
  WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);

  // init bias
  bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias_data_ failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
  auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
  if (in_tensors_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
    for (int i = 0; i < output_channel; ++i) {
      fp16_bias_data[i] = (float16_t)ori_bias[i];
    }
  } else {
    MS_ASSERT(in_tensors_.size() == kInputSize1);
  }
  return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
  int channel_in = conv_param_->input_channel_;
  int ic4 = UP_DIV(channel_in, BLOCK);

  // set data
  auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
  auto matrix_buffer = malloc(trans_matrix_data_size);
  if (matrix_buffer == nullptr) {
    MS_LOG(ERROR) << "malloc matrix_buffer failed.";
    return RET_ERROR;
  }
  memset(matrix_buffer, 0, trans_matrix_data_size);
  trans_weight_ = new Matrix();
  trans_weight_->SetData(matrix_buffer);
  trans_weight_->SetNDim(5);

  std::vector<int> shapes;
  std::vector<int> strides;
  // set shape
  shapes.push_back(input_unit_ * input_unit_);
  shapes.push_back(oc_block_num);
  shapes.push_back(ic4);
  shapes.push_back(C4NUM);
  shapes.push_back(oc_block);
  // set stride
  for (int i = 0; i < 4; i++) {
    int stride = 1;
    for (int j = i + 1; j < 5; j++) {
      stride *= shapes[j];
    }
    strides.push_back(stride);
  }
  trans_weight_->SetShape(shapes);
  trans_weight_->SetStride(strides);
  return RET_OK;
}
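// [Editor's worked example] For the 5-D layout {input_unit^2, oc_block_num,
// ic4, C4NUM, oc_block}, strides[i] is the product of all later dimensions.
// With input_unit_ = 8, oc_block = 8, oc_block_num = 4, ic4 = 2 the shape is
// {64, 4, 2, 4, 8} and the strides come out as
//   strides[0] = 4 * 2 * 4 * 8 = 256
//   strides[1] = 2 * 4 * 8     = 64
//   strides[2] = 4 * 8         = 32
//   strides[3] = 8
// i.e. strides[i] is the element step when index i grows by one.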
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
  int cal_num = 16;
  int channel_in = conv_param_->input_channel_;
  int channel_out = conv_param_->output_channel_;
  int output_h = conv_param_->output_h_;
  int output_w = conv_param_->output_w_;
  int ic4 = UP_DIV(channel_in, C4NUM);
  int oc8 = UP_DIV(channel_out, C8NUM);

  /*=============================fp16_input_============================*/
  size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
                           conv_param_->input_w_ * sizeof(float16_t);
  fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
    return RET_ERROR;
  }

  /*=============================trans_input_============================*/
  size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float16_t);
  trans_input_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
  if (trans_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc trans_input_ failed.";
    return RET_ERROR;
  }
  memset(trans_input_, 0, tile_buffer_size);

  /*=============================gemm_out_============================*/
  gemm_out_ = reinterpret_cast<float16_t *>(
    malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
  if (gemm_out_ == nullptr) {
    MS_LOG(ERROR) << "malloc gemm_out_ failed.";
    return RET_ERROR;
  }

  /*=============================tmp_out_data_============================*/
  int out_w_block = UP_DIV(output_w, output_unit_);
  int out_h_block = UP_DIV(output_h, output_unit_);
  tmp_out_data_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block *
                                                       output_unit_ * output_unit_ * oc8 * C8NUM * sizeof(float16_t)));
  if (tmp_out_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
    return RET_ERROR;
  }

  /*=============================fp16_out_============================*/
  size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ *
                            conv_param_->output_w_ * sizeof(float16_t);
  fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
  if (fp16_out_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_out_ failed.";
    return RET_ERROR;
  }

  /*=============================tmp_data_============================*/
  tmp_data_ =
    reinterpret_cast<float16_t *>(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float16_t)));
  if (tmp_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_data_ failed.";
    return RET_ERROR;
  }
  // clear the whole per-thread scratch buffer, not just the first thread's slice
  memset(tmp_data_, 0, thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float16_t));

  tmp_buffer_address_list_[0] = trans_input_;
  tmp_buffer_address_list_[1] = gemm_out_;
  tmp_buffer_address_list_[2] = tmp_out_data_;
  tmp_buffer_address_list_[3] = tmp_data_;

  /*=============================nhwc4_input_============================*/
  size_t nhwc4_input_size =
    ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
  nhwc4_input_ = malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
    return RET_ERROR;
  }
  memset(nhwc4_input_, 0, nhwc4_input_size);
  return RET_OK;
}
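// [Editor's summary, inferred from the call sites] Roles of the buffers wired
// up above (all fp16):
//   trans_input_  - per-thread Winograd-transformed input tiles
//   gemm_out_     - per-thread GEMM results in the transform domain
//   tmp_out_data_ - inverse-transformed output, still in tiled layout
//   tmp_data_     - small per-thread scratch tile for the transforms
//   nhwc4_input_  - the input repacked to NHWC4 before tiling
// The first four are published to ConvWinogardFp16 through
// tmp_buffer_address_list_.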
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ret = CheckLayout(input_tensor);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check layout failed.";
    return RET_ERROR;
  }
  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_tensor->SetFormat(schema::Format_NHWC);

  // choose input transformer function (4x4 unit or 8x8 unit)
  input_trans_func_ = GetInputTransFuncFp16(input_unit_);
  if (input_trans_func_ == nullptr) {
    MS_LOG(ERROR) << "Get input_trans_func failed.";
    return RET_ERROR;
  }
  output_trans_func_ = GetOutputTransFuncFp16(input_unit_, output_unit_);
  if (output_trans_func_ == nullptr) {
    MS_LOG(ERROR) << "Get output_trans_func_ failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::Init() {
  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return RET_ERROR;
  }
  kernel_unit_ = conv_param_->kernel_h_;
  input_unit_ = output_unit_ + kernel_unit_ - 1;
  conv_param_->input_unit_ = input_unit_;
  conv_param_->output_unit_ = output_unit_;

  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init weight bias failed.";
    return RET_ERROR;
  }
  // malloc tmp buffer
  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  ret = ConfigInputOutput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConfigInputOutput failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::ReSize() {
  if (tmp_data_ != nullptr) {
    free(tmp_data_);
  }
  if (trans_input_ != nullptr) {
    free(trans_input_);
  }
  if (gemm_out_ != nullptr) {
    free(gemm_out_);
  }
  if (tmp_out_data_ != nullptr) {
    free(tmp_out_data_);
  }
  if (nhwc4_input_ != nullptr) {
    free(nhwc4_input_);
  }
  if (fp16_input_ != nullptr) {
    free(fp16_input_);
  }
  if (fp16_out_ != nullptr) {
    free(fp16_out_);
  }

  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return RET_ERROR;
  }
  kernel_unit_ = conv_param_->kernel_h_;
  input_unit_ = output_unit_ + kernel_unit_ - 1;
  conv_param_->input_unit_ = input_unit_;
  conv_param_->output_unit_ = output_unit_;

  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  ret = ConfigInputOutput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConfigInputOutput failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
  ConvWinogardFp16(reinterpret_cast<float16_t *>(nhwc4_input_),
                   reinterpret_cast<float16_t *>(trans_weight_->GetData()),
                   reinterpret_cast<const float16_t *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
                   input_trans_func_, output_trans_func_);
  return RET_OK;
}

int ConvolutionWinogradFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv = reinterpret_cast<ConvolutionWinogradFP16CPUKernel *>(cdata);
  auto error_code = conv->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionWinograd Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }
  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();

  int in_batch = conv_param_->input_batch_;
  int in_h = conv_param_->input_h_;
  int in_w = conv_param_->input_w_;
  int in_channel = conv_param_->input_channel_;
  convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

  int error_code = LiteBackendParallelLaunch(ConvolutionWinogradFp16Impl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
    return RET_ERROR;
  }

  // get real output
  UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
                           conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
  int output_num =
    conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
  if (conv_param_->is_relu_) {
    ReluFp16(execute_output_, execute_output_, output_num);
  } else if (conv_param_->is_relu6_) {
    Relu6Fp16(execute_output_, execute_output_, output_num);
  }
  ConvolutionBaseFP16CPUKernel::IfCastOutput();
  return RET_OK;
}
}  // namespace mindspore::kernel
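// [Editor's summary] Data flow of ConvolutionWinogradFP16CPUKernel::Run():
//   execute_input_ (fp16 NHWC) --convert_func_--> nhwc4_input_ (NHWC4)
//   --ConvWinogardFp16 (tile, transform, GEMM, inverse transform; per thread)-->
//   tmp_out_data_ --UnPackWinogradOutputFp16--> execute_output_
//   --optional ReluFp16/Relu6Fp16--> IfCastOutput (fp16 -> fp32 if required)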
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_ | |||||
| #include <arm_neon.h> | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" | |||||
| #include "src/runtime/kernel/arm/fp16/matrix_fp16.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/optimized_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| public: | |||||
| ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | |||||
| const lite::Primitive *primitive) | |||||
| : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~ConvolutionWinogradFP16CPUKernel() override { | |||||
| if (fp16_input_ != nullptr) { | |||||
| free(fp16_input_); | |||||
| } | |||||
| if (fp16_weight_ != nullptr) { | |||||
| free(fp16_weight_); | |||||
| } | |||||
| if (fp16_out_ != nullptr) { | |||||
| free(fp16_out_); | |||||
| } | |||||
| if (tmp_data_ != nullptr) { | |||||
| free(tmp_data_); | |||||
| } | |||||
| if (trans_input_ != nullptr) { | |||||
| free(trans_input_); | |||||
| } | |||||
| if (gemm_out_ != nullptr) { | |||||
| free(gemm_out_); | |||||
| } | |||||
| if (tmp_out_data_ != nullptr) { | |||||
| free(tmp_out_data_); | |||||
| } | |||||
| delete trans_weight_; | |||||
| } | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int RunImpl(int task_id); | |||||
| int InitWeightBias(); | |||||
| int MallocFilterMatrix(int oc_block, int oc_block_num); | |||||
| int InitTmpBuffer(); | |||||
| int ConfigInputOutput(); | |||||
| private: | |||||
| int kernel_unit_; | |||||
| int input_unit_; | |||||
| int output_unit_; | |||||
| // pointer members start null so the destructor's free/delete calls are safe even if Init never ran | |||||
| float16_t *tmp_data_ = nullptr; | |||||
| float16_t *trans_input_ = nullptr; | |||||
| float16_t *gemm_out_ = nullptr; | |||||
| float16_t *tmp_out_data_ = nullptr; | |||||
| Matrix *trans_weight_ = nullptr; | |||||
| InputTransformUnitFp16Func input_trans_func_ = nullptr; | |||||
| OutputTransformUnitFp16Func output_trans_func_ = nullptr; | |||||
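| // buffer order consumed by ConvWinogardFp16: trans_input, gemm_out, tmp_out_data, tmp_data | |||||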
| TmpBufferAddressFp16 tmp_buffer_address_list_[4]; | |||||
| }; | |||||
| void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit, | |||||
| ConvParameter *conv_param, int oc_block); | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_ | |||||
| @@ -115,10 +115,6 @@ int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| int DeconvolutionDepthwiseFp16CPUKernel::Init() { | int DeconvolutionDepthwiseFp16CPUKernel::Init() { | ||||
| if (context_->infer_shape_interrupt_ && !context_->running_) { | |||||
| set_need_reinit(); | |||||
| return RET_OK; | |||||
| } | |||||
| sliding_ = new SlidingWindowParam; | sliding_ = new SlidingWindowParam; | ||||
| InitSlideParam(); | InitSlideParam(); | ||||
| // conv base init | // conv base init | ||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp16/matrix_fp16.h" | |||||
| namespace mindspore::kernel { | |||||
| void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, | |||||
| bool row) { | |||||
| // naive row-major matmul: C[h][w] = sum_k A[h][k] * B[k][w]; the 'row' flag is currently unused | |||||
| int count = 0; | |||||
| for (int h = 0; h < m; h++) { | |||||
| int h_offset = h * k; | |||||
| for (int w = 0; w < n; w++) { | |||||
| float16_t res = 0; | |||||
| for (int i = 0; i < k; i++) { | |||||
| res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); | |||||
| } | |||||
| *(matrix_c + count) = res; | |||||
| count++; | |||||
| } | |||||
| } | |||||
| } | |||||
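| // usage sketch (illustrative values): multiply a 2x3 A by a 3x2 B into a 2x2 C | |||||
| //   float16_t a[6] = {1, 2, 3, 4, 5, 6}; | |||||
| //   float16_t b[6] = {1, 0, 0, 1, 1, 1}; | |||||
| //   float16_t c[4]; | |||||
| //   MatrixMultiplyFp16(a, b, c, 2, 3, 2, true); | |||||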
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_ | |||||
| #include "src/runtime/kernel/arm/base/matrix.h" | |||||
| namespace mindspore::kernel { | |||||
| void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, | |||||
| bool row); | |||||
| }  // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_ | |||||
| @@ -53,10 +53,6 @@ int PoolingFp16CPUKernel::InitBuffer() { | |||||
| } | } | ||||
| int PoolingFp16CPUKernel::Init() { | int PoolingFp16CPUKernel::Init() { | ||||
| if (context_->infer_shape_interrupt_ && !context_->running_) { | |||||
| set_need_reinit(); | |||||
| return RET_OK; | |||||
| } | |||||
| auto ret = PoolingBaseCPUKernel::Init(); | auto ret = PoolingBaseCPUKernel::Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "PoolingBase Init failed."; | MS_LOG(ERROR) << "PoolingBase Init failed."; | ||||
| @@ -329,10 +329,9 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) { | |||||
| MS_LOG(ERROR) << "gemm_func is nullptr."; | MS_LOG(ERROR) << "gemm_func is nullptr."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data()); | |||||
| ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), reinterpret_cast<float *>(trans_weight_->GetData()), | ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), reinterpret_cast<float *>(trans_weight_->GetData()), | ||||
| reinterpret_cast<const float *>(bias_data_), output_addr, tmp_buffer_address_list_, task_id, | |||||
| conv_param_, input_trans_func_, output_trans_func_, gemm_func_); | |||||
| reinterpret_cast<const float *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_, | |||||
| input_trans_func_, output_trans_func_, gemm_func_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -16,9 +16,7 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_ | #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_ | ||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/fp32/cast.h" | #include "nnacl/fp32/cast.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp16/common_func.h" | |||||
| void ReluFp16(float16_t *data, float16_t *dst, int ele_num) { | |||||
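| // writes max(x, 0) from data into dst; in-place use (data == dst), as the kernels do, is safe | |||||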
| int eight_block = UP_DIV(ele_num, C8NUM); | |||||
| for (int i = 0; i < eight_block - 1; i++) { | |||||
| int index = i * C8NUM; | |||||
| #ifdef ENABLE_NEON | |||||
| float16x8_t relu_data = vld1q_f16(data + index); | |||||
| float16x8_t zero_data = vdupq_n_f16(0); | |||||
| relu_data = vmaxq_f16(relu_data, zero_data); | |||||
| vst1q_f16(dst + index, relu_data); | |||||
| #else | |||||
| // scalar fallback must cover all C8NUM lanes and write to dst, not back into data | |||||
| for (int j = 0; j < C8NUM; ++j) { | |||||
| dst[index + j] = data[index + j] < 0 ? 0 : data[index + j]; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { | |||||
| dst[j] = data[j] < 0 ? 0 : data[j]; | |||||
| } | |||||
| } | |||||
| void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) { | |||||
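| // writes min(max(x, 0), 6) from data into dst; in-place use (data == dst) is safe | |||||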
| int eight_block = UP_DIV(ele_num, C8NUM); | |||||
| for (int i = 0; i < eight_block - 1; i++) { | |||||
| int index = i * C8NUM; | |||||
| #ifdef ENABLE_NEON | |||||
| float16x8_t relu6_data = vld1q_f16(data + index); | |||||
| float16x8_t zero_data = vdupq_n_f16(0); | |||||
| float16x8_t six_data = vdupq_n_f16(6); | |||||
| relu6_data = vmaxq_f16(relu6_data, zero_data); | |||||
| relu6_data = vminq_f16(relu6_data, six_data); | |||||
| vst1q_f16(dst + index, relu6_data); | |||||
| #else | |||||
| for (int j = 0; j < C8NUM; ++j) { | |||||
| float16_t val = data[index + j] < 0 ? 0 : data[index + j]; | |||||
| dst[index + j] = val > 6 ? 6 : val; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { | |||||
| float16_t val = data[j] < 0 ? 0 : data[j]; | |||||
| dst[j] = val > 6 ? 6 : val; | |||||
| } | |||||
| } | |||||
| @@ -39,6 +39,8 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w | |||||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | ||||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); | size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); | ||||
| #endif | #endif | ||||
| void ReluFp16(float16_t *data, float16_t *dst, int ele_num); | |||||
| void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -32,12 +32,21 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||||
| #endif | #endif | ||||
| #ifndef ENABLE_NEON | #ifndef ENABLE_NEON | ||||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | ||||
| size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, | |||||
| size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | |||||
| size_t relu6) { | size_t relu6) { | ||||
| if (!(mode && writeC8)) { | |||||
| IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, out_channel, offset, relu, relu6); | |||||
| } else { | |||||
| IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, out_channel, offset, mode, writeC8, relu, relu6); | |||||
| } | |||||
| } | |||||
| void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||||
| size_t ic4, size_t out_channel, size_t offset, size_t relu, size_t relu6) { | |||||
| const int tile_n = 16; | const int tile_n = 16; | ||||
| for (int i = 0; i < out_channel; i++) { | for (int i = 0; i < out_channel; i++) { | ||||
| int oc8_block = i / 8; | |||||
| int oc8_res = i % 8; | |||||
| int oc8_block = i / C8NUM; | |||||
| int oc8_res = i % C8NUM; | |||||
| int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; | int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; | ||||
| for (int k = 0; k < tile_n; k++) { | for (int k = 0; k < tile_n; k++) { | ||||
| int input_tile_offset = k * C4NUM; | int input_tile_offset = k * C4NUM; | ||||
| @@ -72,32 +81,32 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||||
| } | } | ||||
| } | } | ||||
| void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *weight, const float16_t *bias, | |||||
| size_t step, size_t ic4, size_t output_channel, size_t offset, size_t mode, | |||||
| size_t writeC4, size_t relu, size_t relu6) { | |||||
| void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||||
| size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC8, | |||||
| size_t relu, size_t relu6) { | |||||
| const int tile_num = 16; | const int tile_num = 16; | ||||
| if (mode) { | |||||
| if (mode && writeC8) { | |||||
| for (int i = 0; i < tile_num; i++) { | for (int i = 0; i < tile_num; i++) { | ||||
| int input_tile_offset = i * C4NUM; | int input_tile_offset = i * C4NUM; | ||||
| int output_tile_offset = i * output_channel * 36; | |||||
| int output_tile_offset = i * output_channel * step; | |||||
| for (int j = 0; j < output_channel; j++) { | for (int j = 0; j < output_channel; j++) { | ||||
| int oc8_block = j / 8; | |||||
| int oc8_res = j % 8; | |||||
| int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res; | |||||
| int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res; | |||||
| int oc8_block = j / C8NUM; | |||||
| int oc8_res = j % C8NUM; | |||||
| int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; | |||||
| int out_oc_offset = output_tile_offset + oc8_block * step * C8NUM + oc8_res; | |||||
| for (int n = 0; n < step; n++) { | for (int n = 0; n < step; n++) { | ||||
| int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num; | int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num; | ||||
| int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * 8; | |||||
| int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; | |||||
| int output_kw_offset = out_oc_offset + n * C8NUM; | int output_kw_offset = out_oc_offset + n * C8NUM; | ||||
| float16_t acc = 0; | float16_t acc = 0; | ||||
| for (int k = 0; k < ic4; k++) { | for (int k = 0; k < ic4; k++) { | ||||
| int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM; | int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM; | ||||
| int weight_ic4_offset = weight_kw_offset + k * C4NUM * 8; | |||||
| for (int m = 0; m < 4; m++) { | |||||
| int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; | |||||
| for (int m = 0; m < C4NUM; m++) { | |||||
| int input_ic_offset = input_ic4_offset + m; | int input_ic_offset = input_ic4_offset + m; | ||||
| int weight_ic_offset = weight_ic4_offset + m * 8; | |||||
| int weight_ic_offset = weight_ic4_offset + m * C8NUM; | |||||
| acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; | acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -405,3 +414,91 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // fp16 convolution winograd | |||||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | |||||
| TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | |||||
| InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func) { | |||||
| int thread_num = conv_param->thread_num_; | |||||
| int input_unit = conv_param->input_unit_; | |||||
| int in_batch = conv_param->input_batch_; | |||||
| int in_channel = conv_param->input_channel_; | |||||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||||
| int out_unit = conv_param->output_unit_; | |||||
| int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | |||||
| int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | |||||
| int tile_num = 16; | |||||
| int output_count = out_w_block * out_h_block; | |||||
| int output_tile_count = UP_DIV(output_count, tile_num); | |||||
| int out_channel = conv_param->output_channel_; | |||||
| int oc8 = UP_DIV(out_channel, C8NUM); | |||||
| int input_unit_square = input_unit * input_unit; | |||||
| size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t); | |||||
| float16_t *trans_input = buffer_list[0]; | |||||
| float16_t *gemm_out = buffer_list[1]; | |||||
| float16_t *tmp_out_data = buffer_list[2]; | |||||
| float16_t *tmp_data = buffer_list[3]; | |||||
| int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM; | |||||
| int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; | |||||
| int tmp_data_offset = input_unit_square * C4NUM; | |||||
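| // each thread owns a task_id-indexed slice of trans_input / gemm_out / tmp_data, so the tile loop below is race-free | |||||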
| // step 1 : filter transform (pre-processed offline) | |||||
| // step 2 : input transform (online) | |||||
| for (int b = 0; b < in_batch; b++) { | |||||
| int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | |||||
| int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc8 * C8NUM; | |||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | |||||
| int out_tile_index = thread_id * tile_num; | |||||
| int cal_num = output_count - thread_id * tile_num; | |||||
| cal_num = cal_num > tile_num ? tile_num : cal_num; | |||||
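| // the last chunk may hold fewer than tile_num output tiles | |||||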
| WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, | |||||
| tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | |||||
| input_trans_func); | |||||
| // step 3 : gemm | |||||
| IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, | |||||
| trans_weight, NULL, input_unit_square, ic4, oc8 * C8NUM, output_offset, 1, 1, 0, 0); | |||||
| // step 4 : output transform | |||||
| WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data, | |||||
| cal_num, out_tile_index, out_w_block, conv_param, output_trans_func); | |||||
| } | |||||
| } | |||||
| } | |||||
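| // usage sketch (names illustrative; buffer sizing follows the *_offset strides above, one slice per thread): | |||||
| //   float16_t *buffers[4] = {trans_input, gemm_out, tmp_out_data, tmp_data}; | |||||
| //   ConvWinogardFp16(nhwc4_input, trans_weight, bias, buffers, task_id, conv_param, in_func, out_func); | |||||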
| void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, | |||||
| int output_unit) { | |||||
| int out_h_block_num = UP_DIV(height, output_unit); | |||||
| int out_w_block_num = UP_DIV(width, output_unit); | |||||
| int c8 = UP_DIV(channel, C8NUM); | |||||
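| // src is tiled as [batch][oc8][h_block * unit][w_block * unit][C8]; dst is plain NHWC, | |||||
| // so whole C8 blocks are copied below and the channel tail is handled per element | |||||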
| for (int b = 0; b < batch; b++) { | |||||
| int src_batch_offset = b * c8 * C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; | |||||
| int dst_batch_offset = b * height * width * channel; | |||||
| for (int h = 0; h < height; h++) { | |||||
| int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit); | |||||
| int dst_h_offset = dst_batch_offset + h * width * channel; | |||||
| for (int w = 0; w < width; w++) { | |||||
| int src_w_offset = src_h_offset + w * C8NUM; | |||||
| int dst_w_offset = dst_h_offset + w * channel; | |||||
| for (int c = 0; c < c8 - 1; c++) { | |||||
| int src_c8_offset = src_w_offset + c * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; | |||||
| int dst_c8_offset = dst_w_offset + c * C8NUM; | |||||
| #ifdef ENABLE_NEON | |||||
| vst1q_f16(dst + dst_c8_offset, vld1q_f16(src + src_c8_offset)); | |||||
| #else | |||||
| for (int i = 0; i < C8NUM; ++i) { | |||||
| dst[dst_c8_offset + i] = src[src_c8_offset + i]; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| int c_res = channel - (c8 - 1) * C8NUM; | |||||
| int src_c_res_offset = (c8 - 1) * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; | |||||
| int dst_c_res_offset = (c8 - 1) * C8NUM; | |||||
| for (int c = 0; c < c_res; c++) { | |||||
| int src_c8_res_offset = src_w_offset + src_c_res_offset + c; | |||||
| int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c; | |||||
| dst[dst_c8_res_offset] = src[src_c8_res_offset]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -18,11 +18,22 @@ | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #include "nnacl/conv_parameter.h" | #include "nnacl/conv_parameter.h" | ||||
| #include "nnacl/fp16/winograd_utils_fp16.h" | |||||
| #include "nnacl/fp16/winograd_transform_fp16.h" | |||||
| typedef float16_t *TmpBufferAddressFp16; | |||||
| #ifndef ENABLE_NEON | #ifndef ENABLE_NEON | ||||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | ||||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, | |||||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, | |||||
| size_t relu6); | size_t relu6); | ||||
| void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||||
| size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6); | |||||
| void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, | |||||
| size_t relu6); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -48,6 +59,14 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, | void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, | ||||
| float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, | float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, | ||||
| int task_id, ConvParameter *conv_param); | int task_id, ConvParameter *conv_param); | ||||
| // fp16 convolution winograd | |||||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | |||||
| TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | |||||
| InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func); | |||||
| void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, | |||||
| int output_unit); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -534,3 +534,95 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // fp16 common winograd | |||||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | |||||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | |||||
| InputTransformUnitFp16Func input_trans_func) { | |||||
| int tile_num = 16; | |||||
| int input_unit = conv_param->input_unit_; | |||||
| int output_unit = conv_param->output_unit_; | |||||
| int in_channel = conv_param->input_channel_; | |||||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||||
| int pad_h = conv_param->pad_h_; | |||||
| int pad_w = conv_param->pad_w_; | |||||
| int input_h = conv_param->input_h_; | |||||
| int input_w = conv_param->input_w_; | |||||
| if (out_w_block_num == 0) { | |||||
| return; | |||||
| } | |||||
| for (int c = 0; c < cal_num; c++) { // loop over the tiles actually computed in this pass | |||||
| int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; | |||||
| int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; | |||||
| int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; | |||||
| int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; | |||||
| int src_x_e = src_x_s + input_unit; | |||||
| int src_y_e = src_y_s + input_unit; | |||||
| int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); | |||||
| int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); | |||||
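| // interval_[xy]_[se] clamp the patch to the valid image area; pad rows/cols stay zero in tmp_data | |||||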
| int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s); | |||||
| int dst_plane_offset = c * C4NUM; | |||||
| for (int ic = 0; ic < ic4; ic++) { | |||||
| // clear tmp buffer | |||||
| memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float16_t)); | |||||
| // get real input block with padding | |||||
| int src_ic4_offset = src_plane_offset + ic * C4NUM; | |||||
| for (int interval = interval_y_s; interval < interval_y_e; interval++) { | |||||
| int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM; | |||||
| int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM; | |||||
| for (int j = 0; j < (interval_x_e - interval_x_s); j++) { | |||||
| int src_x_offset = src_y_offset + j * ic4 * C4NUM; | |||||
| int dst_x_offset = dst_y_offset + j * C4NUM; | |||||
| const float16_t *src_addr = input_data + src_x_offset; | |||||
| float16_t *dst_addr = tmp_data + dst_x_offset; | |||||
| #ifdef ENABLE_NEON | |||||
| vst1_f16(dst_addr, vld1_f16(src_addr)); | |||||
| #else | |||||
| for (int k = 0; k < C4NUM; k++) { | |||||
| dst_addr[k] = src_addr[k]; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| } | |||||
| // input transform | |||||
| int dst_ic4_offset = dst_plane_offset + ic * tile_num * C4NUM; | |||||
| size_t dst_step = ic4 * C4NUM * tile_num; | |||||
| float16_t *trans_input_ptr = trans_input + dst_ic4_offset; | |||||
| input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step); | |||||
| } | |||||
| out_tile_index++; | |||||
| } // cal_tile_num loop | |||||
| } | |||||
| void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, | |||||
| int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param, | |||||
| OutputTransformUnitFp16Func output_trans_func) { | |||||
| int output_unit = conv_param->output_unit_; | |||||
| int output_w = conv_param->output_w_; | |||||
| int output_unit_block = UP_DIV(output_w, output_unit); | |||||
| int output_channel = conv_param->output_channel_; | |||||
| int oc8 = UP_DIV(output_channel, C8NUM); | |||||
| int input_unit = conv_param->input_unit_; | |||||
| if (output_unit_num == 0) { | |||||
| return; | |||||
| } | |||||
| for (int i = 0; i < cal_num; i++) { | |||||
| int dst_x_s = out_tile_index % output_unit_num; | |||||
| int dst_y_s = out_tile_index / output_unit_num; | |||||
| int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; | |||||
| int dst_tile_offset = C8NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit); | |||||
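| // offsets follow the C8-interleaved tmp_out layout that UnPackWinogradOutputFp16 later unpacks | |||||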
| for (int j = 0; j < oc8; j++) { | |||||
| int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; | |||||
| int dst_oc8_offset = | |||||
| dst_tile_offset + j * C8NUM * output_unit_block * output_unit_block * output_unit * output_unit; | |||||
| const float16_t *src_ptr = gemm_out + src_oc8_offset; | |||||
| const float16_t *bias_ptr = bias_data + j * C8NUM; | |||||
| float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; | |||||
| output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_unit_block * output_unit); | |||||
| } | |||||
| out_tile_index++; | |||||
| } | |||||
| } | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include <string.h> | #include <string.h> | ||||
| #include "nnacl/fp16/pack_fp16.h" | #include "nnacl/fp16/pack_fp16.h" | ||||
| #include "nnacl/fp16/conv_fp16.h" | #include "nnacl/fp16/conv_fp16.h" | ||||
| #include "nnacl/fp16/winograd_utils_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| @@ -39,6 +40,15 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data | |||||
| void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, | void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, | ||||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | ||||
| // fp16 common winograd | |||||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | |||||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | |||||
| InputTransformUnitFp16Func input_trans_func); | |||||
| void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, | |||||
| int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param, | |||||
| OutputTransformUnitFp16Func output_trans_func); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_ | |||||
| #include <arm_neon.h> | |||||
| #include "nnacl/conv_parameter.h" | |||||
| #include "nnacl/op_base.h" | |||||
| typedef void (*InputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| typedef void (*OutputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
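| // src_step / dst_step are element strides between adjacent points of one transform unit in the packed buffers | |||||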
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | |||||
| int src_step, int dst_step); | |||||
| InputTransformUnitFp16Func GetInputTransFuncFp16(int input_unit); | |||||
| OutputTransformUnitFp16Func GetOutputTransFuncFp16(int input_unit, int output_unit); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_ | |||||
| @@ -243,10 +243,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output | |||||
| } | } | ||||
| // fp32 conv winograd | // fp32 conv winograd | ||||
| void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, | |||||
| TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, | |||||
| InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func, | |||||
| GEMM_FUNC_FP32 gemm_func) { | |||||
| void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list, | |||||
| int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func, | |||||
| OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func) { | |||||
| int thread_num = conv_param->thread_num_; | int thread_num = conv_param->thread_num_; | ||||
| int input_unit = conv_param->input_unit_; | int input_unit = conv_param->input_unit_; | ||||
| int in_batch = conv_param->input_batch_; | int in_batch = conv_param->input_batch_; | ||||
| @@ -57,10 +57,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output | |||||
| StrassenMatMulParameter matmul_param); | StrassenMatMulParameter matmul_param); | ||||
| // fp32 convolution winograd | // fp32 convolution winograd | ||||
| void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, | |||||
| TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, | |||||
| InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func, | |||||
| GEMM_FUNC_FP32 gemm_func); | |||||
| void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list, | |||||
| int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func, | |||||
| OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func); | |||||
| void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit); | void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit); | ||||