| @@ -23,10 +23,6 @@ Matrix *TransformMatrixGenerator(int m, int k) { | |||
| auto aa = malloc(m * k * sizeof(float)); | |||
| matrix->SetData(aa); | |||
| matrix->SetNum(m, k); | |||
| // matrix->data_ = malloc(m * k * sizeof(float)); | |||
| // matrix->m_ = m; | |||
| // matrix->k_ = k; | |||
| // matrix->row_major_ = true; | |||
| return matrix; | |||
| } | |||
| @@ -65,26 +65,6 @@ class Matrix { | |||
| int n_dim_; | |||
| bool row_major_; | |||
| }; | |||
| // struct Matrix { | |||
| // void *data_; | |||
| // int *shape_; | |||
| // int *stride_; | |||
| // int m_; | |||
| // int k_; | |||
| // int n_dim_; | |||
| // bool row_major_; | |||
| // ~Matrix() { | |||
| // if (data_ != nullptr) { | |||
| // free(data_); | |||
| // } | |||
| // if (shape_ != nullptr) { | |||
| // free(shape_); | |||
| // } | |||
| // if (shape_ != nullptr) { | |||
| // free(stride_); | |||
| // } | |||
| // } | |||
| //}; | |||
| Matrix *TransformMatrixGenerator(int m, int k); | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" | |||
| @@ -265,11 +266,9 @@ int Convolution3x3FP16CPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||
| auto input_ele_num = input_tensor->ElementsNum(); | |||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| auto input_element_num = input_tensor->ElementsNum(); | |||
| for (int i = 0; i < input_element_num; ++i) { | |||
| fp16_input_[i] = (float16_t)ori_input_data[i]; | |||
| } | |||
| Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num); | |||
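// Float32ToFloat16 / Float16ToFloat32 come from nnacl/fp16/cast_fp16.h; they are assumed to be
// plain element-wise casts equivalent to the loop above, roughly (sketch only, not the actual nnacl source):
//   void Float32ToFloat16(const float *input, float16_t *output, int number) {
//     for (int i = 0; i < number; ++i) {
//       output[i] = (float16_t)input[i];
//     }
//   }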
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| @@ -285,12 +284,9 @@ int Convolution3x3FP16CPUKernel::Run() { | |||
| // cast fp16 out to fp32 data | |||
| auto out_tensor = out_tensors_.at(kOutputIndex); | |||
| auto out_ele_num = out_tensor->ElementsNum(); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||
| auto output_element_num = out_tensor->ElementsNum(); | |||
| for (int j = 0; j < output_element_num; ++j) { | |||
| output_addr[j] = static_cast<float>(fp16_out_[j]); | |||
| } | |||
| Float16ToFloat32(fp16_out_, output_addr, out_ele_num); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -15,8 +15,11 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp16/convolution_fp16.h" | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| @@ -231,10 +234,8 @@ int ConvolutionFP16CPUKernel::Run() { | |||
| } | |||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| auto input_element_num = input_tensor->ElementsNum(); | |||
| for (int i = 0; i < input_element_num; ++i) { | |||
| fp16_input_[i] = (float16_t)ori_input_data[i]; | |||
| } | |||
| auto input_ele_num = input_tensor->ElementsNum(); | |||
| Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num); | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| @@ -251,10 +252,8 @@ int ConvolutionFP16CPUKernel::Run() { | |||
| // cast fp16 out to fp32 data | |||
| auto out_tensor = out_tensors_.at(kOutputIndex); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||
| auto output_element_num = out_tensor->ElementsNum(); | |||
| for (int j = 0; j < output_element_num; ++j) { | |||
| output_addr[j] = static_cast<float>(fp16_out_[j]); | |||
| } | |||
| auto out_ele_num = out_tensor->ElementsNum(); | |||
| Float16ToFloat32(fp16_out_, output_addr, out_ele_num); | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,269 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" | |||
| #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Conv2D; | |||
| namespace mindspore::kernel { | |||
| int ConvolutionSWFP16CPUKernel::ProcessFilter() { | |||
| int kernel_h = conv_param_->kernel_h_; | |||
| int kernel_w = conv_param_->kernel_w_; | |||
| int in_channel = conv_param_->input_channel_; | |||
| int out_channel = conv_param_->output_channel_; | |||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||
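// UP_DIV is the nnacl ceil-division macro, so ic4 is the number of 4-channel blocks needed to
// cover in_channel (in_channel padded up to a multiple of C4NUM, then divided by 4).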
| auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data()); | |||
| size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); | |||
| fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size)); | |||
| if (fp16_weight_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc fp16_weight_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| // cast origin fp32 weight data to fp16 data | |||
| for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { | |||
| fp16_weight_[i] = (float16_t)origin_weight[i]; | |||
| } | |||
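// Repack the fp16 weights from [out_channel][kernel_h * kernel_w][in_channel] into rows padded to
// ic4 * C4NUM per kernel position; the padding slots keep the zeros written by InitWeightBias's memset.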
| for (int oc = 0; oc < out_channel; ++oc) { | |||
| int src_oc_offset = oc * kernel_h * kernel_w * in_channel; | |||
| int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM; | |||
| for (int i = 0; i < kernel_h * kernel_w; ++i) { | |||
| const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel; | |||
| float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM; | |||
| memcpy(dst, src, in_channel * sizeof(float16_t)); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::InitWeightBias() { | |||
| int kernel_h = conv_param_->kernel_h_; | |||
| int kernel_w = conv_param_->kernel_w_; | |||
| int in_channel = conv_param_->input_channel_; | |||
| int out_channel = conv_param_->output_channel_; | |||
| int oc4 = UP_DIV(out_channel, C4NUM); | |||
| int ic4 = UP_DIV(in_channel, C4NUM); | |||
| int kernel_plane = kernel_h * kernel_w; | |||
| int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; | |||
| // init weight | |||
| packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t))); | |||
| if (packed_weight_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc packed_weight_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); | |||
| auto ret = ProcessFilter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Process filter failed."; | |||
| return ret; | |||
| } | |||
| // init bias | |||
| bias_data_ = malloc(oc4 * C4NUM * sizeof(float16_t)); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc bias_data_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(bias_data_, 0, oc4 * C4NUM * sizeof(float16_t)); | |||
| auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_); | |||
| if (in_tensors_.size() == kInputSize2) { | |||
| auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data()); | |||
| for (int i = 0; i < out_channel; ++i) { | |||
| fp16_bias_data[i] = (float16_t)ori_bias[i]; | |||
| } | |||
| } else { | |||
MS_ASSERT(in_tensors_.size() == kInputSize1);
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::InitTmpBuffer() { | |||
| int in_channel = conv_param_->input_channel_; | |||
| int out_channel = conv_param_->output_channel_; | |||
| int channel_block = UP_DIV(in_channel, C4NUM); | |||
| int oc4 = UP_DIV(out_channel, C4NUM); | |||
| /*=============================fp16_input_============================*/ | |||
| size_t fp16_input_size = | |||
| in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); | |||
| fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size)); | |||
| if (fp16_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc fp16_input_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| /*=============================nhwc4_input_============================*/ | |||
| size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * | |||
| conv_param_->input_w_ * sizeof(float16_t); | |||
| nhwc4_input_ = malloc(nhwc4_input_size); | |||
| if (nhwc4_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(nhwc4_input_, 0, nhwc4_input_size); | |||
| /*=============================tmp_output_block_============================*/ | |||
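// NHWC4-laid-out scratch output: ConvSWFp16 writes here when output_channel is not a multiple of
// C4NUM and unpacks it to NHWC afterwards; otherwise it writes directly into fp16_out_.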
| tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * conv_param_->output_h_ * | |||
| conv_param_->output_w_ * oc4 * C4NUM * sizeof(float16_t))); | |||
| if (tmp_output_block_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| /*=============================fp16_out_============================*/ | |||
| size_t fp16_output_size = | |||
| out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t); | |||
| fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size)); | |||
| if (fp16_out_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc fp16_out_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void ConvolutionSWFP16CPUKernel::ConfigInputOutput() { | |||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||
| auto input_format = input_tensor->GetFormat(); | |||
| schema::Format execute_format = schema::Format_NHWC4; | |||
| convert_func_ = LayoutTransformFp16(input_format, execute_format); | |||
| if (convert_func_ == nullptr) { | |||
| MS_LOG(ERROR) << "layout convert func is nullptr."; | |||
| return; | |||
| } | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::Init() { | |||
| if (context_->infer_shape_interrupt_ && !context_->running_) { | |||
| set_need_reinit(); | |||
| return RET_OK; | |||
| } | |||
| auto ret = ConvolutionBaseCPUKernel::Init(); | |||
| if (ret != RET_OK) { | |||
MS_LOG(ERROR) << "ConvolutionBase init failed! ret: " << ret;
| return ret; | |||
| } | |||
| ret = InitWeightBias(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init weight bias failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ret = InitTmpBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init tmp buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ConfigInputOutput(); | |||
| // init sliding window param | |||
| slidingWindow_param_ = new SlidingWindowParam; | |||
| InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::ReSize() { | |||
| if (tmp_output_block_ != nullptr) { | |||
| free(tmp_output_block_); | |||
| } | |||
| if (nhwc4_input_ != nullptr) { | |||
| free(nhwc4_input_); | |||
| } | |||
| if (fp16_input_ != nullptr) { | |||
| free(fp16_input_); | |||
| } | |||
| if (fp16_out_ != nullptr) { | |||
| free(fp16_out_); | |||
| } | |||
| delete slidingWindow_param_; | |||
| auto ret = ConvolutionBaseCPUKernel::Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionBase init failed."; | |||
| return ret; | |||
| } | |||
| ret = InitTmpBuffer(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init tmp buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| // init sliding window param | |||
| slidingWindow_param_ = new SlidingWindowParam; | |||
| InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) { | |||
| ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_), | |||
| tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto conv = reinterpret_cast<ConvolutionSWFP16CPUKernel *>(cdata); | |||
| auto error_code = conv->RunImpl(task_id); | |||
| if (error_code != RET_OK) { | |||
MS_LOG(ERROR) << "ConvolutionSWFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionSWFP16CPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||
| auto input_ele_num = input_tensor->ElementsNum(); | |||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num); | |||
| int in_batch = conv_param_->input_batch_; | |||
| int in_h = conv_param_->input_h_; | |||
| int in_w = conv_param_->input_w_; | |||
| int in_channel = conv_param_->input_channel_; | |||
| convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); | |||
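// convert_func_ was selected in ConfigInputOutput via LayoutTransformFp16 and converts the fp16
// input from its original format into the NHWC4 layout expected by the sliding-window kernel.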
| int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| // cast fp16 out to fp32 data | |||
| auto out_tensor = out_tensors_.at(kOutputIndex); | |||
| auto out_ele_num = out_tensor->ElementsNum(); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||
| Float16ToFloat32(fp16_out_, output_addr, out_ele_num); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ | |||
| #include <arm_neon.h> | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||
| namespace mindspore::kernel { | |||
| class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel { | |||
| public: | |||
| ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | |||
| const lite::Primitive *primitive) | |||
| : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ConvolutionSWFP16CPUKernel() override { | |||
| if (fp16_input_ != nullptr) { | |||
| free(fp16_input_); | |||
| } | |||
| if (fp16_weight_ != nullptr) { | |||
| free(fp16_weight_); | |||
| } | |||
| if (fp16_out_ != nullptr) { | |||
| free(fp16_out_); | |||
| } | |||
| if (packed_weight_ != nullptr) { | |||
| free(packed_weight_); | |||
| } | |||
| if (tmp_output_block_ != nullptr) { | |||
| free(tmp_output_block_); | |||
| } | |||
| delete slidingWindow_param_; | |||
| } | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int RunImpl(int task_id); | |||
| int InitWeightBias(); | |||
| int InitTmpBuffer(); | |||
| void ConfigInputOutput(); | |||
| int ProcessFilter(); | |||
| private: | |||
// initialize to nullptr so the destructor's free/delete calls are safe if Init() fails early
float16_t *fp16_input_ = nullptr;
float16_t *fp16_weight_ = nullptr;
float16_t *fp16_out_ = nullptr;
float16_t *packed_weight_ = nullptr;
float16_t *tmp_output_block_ = nullptr;
SlidingWindowParam *slidingWindow_param_ = nullptr;
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ | |||
| @@ -259,6 +259,21 @@ int Convolution3x3CPUKernel::Run() { | |||
| MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| auto is_relu = conv_param_->is_relu_; | |||
| auto is_relu6 = conv_param_->is_relu6_; | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data()); | |||
| PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_, | |||
| conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); | |||
| int output_num = | |||
| conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_; | |||
| if (is_relu) { | |||
| ReluFp32(output_addr, output_addr, output_num); | |||
| } else if (is_relu6) { | |||
| Relu6Fp32(output_addr, output_addr, output_num); | |||
| } else { | |||
| // do nothing | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -189,8 +189,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||
| /*=============================tmp_out_data_============================*/ | |||
| int out_w_block = UP_DIV(output_w, output_unit_); | |||
| int out_h_block = UP_DIV(output_h, output_unit_); | |||
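// tmp_out_data_ holds the tile-aligned, oc4-padded winograd output for every image in the batch;
// UnPackWinogradOutput in Run() later copies the real elements into the output tensor.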
| tmp_out_data_ = reinterpret_cast<float *>( | |||
| malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); | |||
| tmp_out_data_ = reinterpret_cast<float *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block * | |||
| output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); | |||
| if (tmp_out_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc tmp_out_data_ failed."; | |||
| return RET_ERROR; | |||
| @@ -365,6 +365,22 @@ int ConvolutionWinogradCPUKernel::Run() { | |||
| MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| // get real output | |||
| auto out_tensor = out_tensors_.front(); | |||
| auto out_data = reinterpret_cast<float *>(out_tensor->Data()); | |||
| UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_, | |||
| conv_param_->output_w_, conv_param_->output_channel_, output_unit_); | |||
| int output_num = | |||
| conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_; | |||
| if (conv_param_->is_relu_) { | |||
| ReluFp32(out_data, out_data, output_num); | |||
| } else if (conv_param_->is_relu6_) { | |||
| Relu6Fp32(out_data, out_data, output_num); | |||
| } else { | |||
| // do nothing | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -18,7 +18,6 @@ | |||
| #include "nnacl/fp16/pack_fp16.h" | |||
| #include "nnacl/fp16/winograd_transform_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -112,6 +111,209 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w | |||
| } | |||
| #endif | |||
| void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, | |||
| int width, int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic, bool is_relu, | |||
| bool is_relu6) { | |||
| int ic8 = ic / C8NUM; | |||
int ic8_res = ic % C8NUM;  // residual channels after the full C8NUM blocks (mirrors SWCenterFp16)
| int ic4 = ic8_res / C4NUM; | |||
| for (int c = 0; c < C4NUM; c++) { | |||
| dst[c] = 0; | |||
| } | |||
| const float16_t *weight_oc = weight; | |||
| for (int oc = 0; oc < C4NUM; ++oc) { | |||
| const float16_t *weight_kh = weight_oc; | |||
| const float16_t *src_kh = src; | |||
| for (int kh = 0; kh < height; kh++) { | |||
| const float16_t *src_kw = src_kh; | |||
| const float16_t *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < width; kw++) { | |||
| const float16_t *src_ic8 = src_kw; | |||
| const float16_t *weight_ic8 = weight_kw; | |||
| for (int rc = 0; rc < ic8; ++rc) { | |||
| for (int c = 0; c < C8NUM; c++) { | |||
| dst[oc] += src_ic8[c] * weight_ic8[c]; | |||
| } | |||
| src_ic8 += C8NUM; | |||
| weight_ic8 += C8NUM; | |||
| } // ic8 loop | |||
| const float16_t *src_ic4 = src_ic8; | |||
| const float16_t *weight_ic4 = weight_ic8; | |||
| for (int rc = 0; rc < ic4; ++rc) { | |||
| for (int c = 0; c < C4NUM; c++) { | |||
| dst[oc] += src_ic4[c] * weight_ic4[c]; | |||
| } | |||
| src_ic4 += C4NUM; | |||
| weight_ic4 += C4NUM; | |||
| } // ic4 loop | |||
| src_kw += in_kw_step; | |||
| weight_kw += ic4 * C4NUM; | |||
| } // kernel_w loop | |||
| src_kh += in_kh_step; | |||
| weight_kh += kernel_w * ic4 * C4NUM; | |||
| } // kernel_h loop | |||
| dst[oc] += bias[oc]; | |||
| dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]); | |||
| dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]); | |||
| weight_oc += kernel_h * kernel_w * ic4 * C4NUM; | |||
| } // oc loop | |||
| } | |||
| void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, | |||
| int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { | |||
| float16_t *dst_h = dst + top * sliding->out_h_step_; | |||
| for (int oh = top; oh < bottom; oh++) { | |||
| int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; | |||
| int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); | |||
| int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); | |||
| const float16_t *src_h = src + ih * sliding->in_h_step_; | |||
| float16_t *dst_kernel = dst_h + left * sliding->block_channel_; | |||
| for (int ow = left; ow < right; ow++) { | |||
| int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; | |||
| int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); | |||
| int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); | |||
| const float16_t *src_w = src_h + iw * sliding->ic4_channel_; | |||
| const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; | |||
| const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_; | |||
| SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, | |||
| sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, | |||
| sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_); | |||
| dst_kernel += sliding->block_channel_; | |||
| } // width loop | |||
| dst_h += sliding->out_h_step_; | |||
| } // height loop | |||
| } | |||
| void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, | |||
| int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, | |||
| int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { | |||
| int ic8 = ic / C8NUM; | |||
| int ic8_res = ic % C8NUM; | |||
| int ic4 = ic8_res / C4NUM; | |||
| float16_t *dst_h = dst; | |||
| const float16_t *src_h = src; | |||
| for (int oh = 0; oh < height; oh++) { | |||
| float16_t *dst_w = dst_h; | |||
| const float16_t *src_w = src_h; | |||
| for (int ow = 0; ow < width; ow++) { | |||
| const float16_t *weight_oc = weight; | |||
| for (int c = 0; c < C4NUM; c++) { | |||
| dst_w[c] = 0; | |||
| } | |||
| for (int oc = 0; oc < C4NUM; oc++) { | |||
| const float16_t *weight_kh = weight_oc; | |||
| const float16_t *src_kh = src_w; | |||
| for (int kh = 0; kh < kernel_h; kh++) { | |||
| const float16_t *src_kw = src_kh; | |||
| const float16_t *weight_kw = weight_kh; | |||
| for (int kw = 0; kw < kernel_w; kw++) { | |||
| const float16_t *src_ic8 = src_kw; | |||
| const float16_t *weight_ic8 = weight_kw; | |||
| for (int rc = 0; rc < ic8; ++rc) { | |||
| for (int c = 0; c < C8NUM; c++) { | |||
| dst_w[oc] += src_ic8[c] * weight_ic8[c]; | |||
| } | |||
| src_ic8 += C8NUM; | |||
| weight_ic8 += C8NUM; | |||
| } // ic8 loop | |||
| const float16_t *src_ic4 = src_ic8; | |||
| const float16_t *weight_ic4 = weight_ic8; | |||
| for (int rc = 0; rc < ic4; ++rc) { | |||
| for (int c = 0; c < C4NUM; c++) { | |||
| dst_w[oc] += src_ic4[c] * weight_ic4[c]; | |||
| } | |||
| src_ic4 += C4NUM; | |||
| weight_ic4 += C4NUM; | |||
| } // ic4 loop | |||
| src_kw += in_kw_step; | |||
| weight_kw += ic4 * C4NUM; | |||
| } // kernel_w loop | |||
| src_kh += in_kh_step; | |||
| weight_kh += kernel_w * ic4 * C4NUM; | |||
| } // kernel_h loop | |||
// add bias, then apply relu / relu6
| dst_w[oc] += bias[oc]; | |||
| dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]); | |||
| dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]); | |||
| weight_oc += kernel_h * kernel_w * ic4 * C4NUM; | |||
| } // oc block | |||
| dst_w += block_channel; | |||
| src_w += in_sw_step; | |||
| } // dst_width loop | |||
| dst_h += out_h_step; | |||
| src_h += in_sh_step; | |||
| } // dst_height loop | |||
| } | |||
| // fp16 conv sliding window | |||
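// Output pixels whose kernel footprint overlaps the padding (top/bottom/left/right strips) are
// handled per-pixel by SWBorderFp16; the interior region is handled by SWCenterFp16 without
// per-pixel bounds checks.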
| void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, | |||
| float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, | |||
| SlidingWindowParam *slidingWindow_param) { | |||
| int oc4_res = conv_param->output_channel_ % C4NUM; | |||
| const float16_t *src = input_data; | |||
| float16_t *dst; | |||
| if (oc4_res == 0) { | |||
| dst = output_data; | |||
| } else { | |||
| dst = tmp_out_block; | |||
| } | |||
| for (int b = 0; b < conv_param->output_batch_; b++) { | |||
| for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) { | |||
| const float16_t *src_data = src; | |||
| float16_t *dst_data = dst + oc * C4NUM; | |||
| const float16_t *weight = packed_weight + oc * slidingWindow_param->kernel_step_; | |||
| const float16_t *bias = bias_data + oc * C4NUM; | |||
| SWBorderFp16(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param, | |||
| slidingWindow_param); | |||
| SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0, | |||
| conv_param->output_w_, conv_param, slidingWindow_param); | |||
| SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0, | |||
| slidingWindow_param->left_, conv_param, slidingWindow_param); | |||
| SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, | |||
| slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param); | |||
| if (slidingWindow_param->right_ > slidingWindow_param->left_ && | |||
| slidingWindow_param->bottom_ > slidingWindow_param->top_) { | |||
| int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_; | |||
| int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_; | |||
| const float16_t *in_t = | |||
| src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_; | |||
| float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ + | |||
| slidingWindow_param->left_ * slidingWindow_param->block_channel_; | |||
| SWCenterFp16(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_, | |||
| slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, | |||
| conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, | |||
| slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_, | |||
| slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_, | |||
| slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_); | |||
| } | |||
| } // output C4 loop | |||
| src += slidingWindow_param->in_step_; | |||
| dst += slidingWindow_param->out_step_; | |||
| } // batch loop | |||
| // output nhwc4 | |||
| if (oc4_res != 0) { | |||
| PackNHWC4ToNHWCFp16((const void *)tmp_out_block, (void *)output_data, conv_param->output_batch_, | |||
| conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); | |||
| } | |||
| } | |||
| // fp16 convolution common (im2col+gemm) | |||
| void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | |||
| float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { | |||
| @@ -144,7 +346,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||
| // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * in_channel * in_h * in_w; | |||
| int in_batch_offset = b * ic4 * C4NUM * in_h * in_w; | |||
| int out_batch_offset = b * out_channel * out_h * out_w; | |||
| int gemm_in_batch_offset = b * packed_input_size; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| @@ -172,7 +374,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||
| void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, | |||
| float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, | |||
| int task_id, ConvParameter *conv_param) { | |||
| // todo | |||
| int thread_count = conv_param->thread_num_; | |||
| int tile_num = 16; | |||
| const int output_unit = 4; | |||
| @@ -195,6 +396,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||
| int input_batch = conv_param->input_batch_; | |||
| for (int batch = 0; batch < input_batch; batch++) { | |||
| int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | |||
| int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| int start_index = thread_id * tile_num; | |||
| int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; | |||
| @@ -207,8 +410,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||
| tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM, | |||
| oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); | |||
| Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, | |||
| real_cal_num, out_w_block, conv_param); | |||
| Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, | |||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||
| } | |||
| } | |||
| @@ -217,7 +420,10 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||
| bool relu = conv_param->is_relu_; | |||
| bool relu6 = conv_param->is_relu6_; | |||
| for (int batch = 0; batch < output_batch; batch++) { | |||
| int batch_size = batch * output_channel * output_h * output_w; | |||
| int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit; | |||
| int ro_batch_size = batch * output_channel * output_h * output_w; | |||
| const float16_t *batch_tmp_out = tmp_out + tmp_out_batch_offset; | |||
| float16_t *batch_out = output_data + ro_batch_size; | |||
| for (int h = 0; h < output_h; h++) { | |||
| for (int w = 0; w < output_w; w++) { | |||
| for (int c = 0; c < output_channel; c++) { | |||
| @@ -226,12 +432,12 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 | |||
| int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + | |||
| C8NUM * (h * out_w_block * output_unit + w) + oc8_res; | |||
| int dst_offset = (h * output_w + w) * output_channel + c; | |||
| (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; | |||
| (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0]; | |||
| if (relu) { | |||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; | |||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; | |||
| } else if (relu6) { | |||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; | |||
| (output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0]; | |||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; | |||
| (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0]; | |||
| } | |||
| } | |||
| } | |||
| @@ -28,6 +28,18 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, | |||
| int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding); | |||
| void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, | |||
| int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, | |||
| int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); | |||
| // fp16 sliding window | |||
| void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, | |||
| float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, | |||
| SlidingWindowParam *slidingWindow_param); | |||
| // fp16 convolution common (im2col+gemm) | |||
| void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | |||
| float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param); | |||
| @@ -219,6 +219,24 @@ void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int c | |||
| } | |||
| } | |||
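// Unpack NHWC4 (channel padded up to a multiple of C4NUM) back to tightly packed NHWC.
// Example: channel = 6 gives c4 = 2, so each pixel occupies 8 float16 slots in src but only the
// first 6 are copied to dst. When channel is already a multiple of 4 the two layouts match and a
// single memcpy of the whole buffer suffices.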
| void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| int ic_remainder_ = channel % C4NUM; | |||
| if (ic_remainder_ != 0) { | |||
| int nhwc_batch_unit_offset = channel * plane; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * c4 * C4NUM * plane; | |||
| for (int i = 0; i < plane; i++) { | |||
| memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel, | |||
| (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t)); | |||
| } | |||
| } | |||
| } else { | |||
| size_t ori_input_size = batch * plane * channel * sizeof(float16_t); | |||
| memcpy((float16_t *)dst, (float16_t *)src, ori_input_size); | |||
| } | |||
| } | |||
| void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int nhwc4_batch_offset = 0; | |||
| int ic4 = UP_DIV(channel, C4NUM); | |||
| @@ -41,6 +41,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int | |||
| void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); | |||
| @@ -217,7 +217,7 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons | |||
| size_t output_offset = out_channel * sizeof(float); | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * in_channel * in_h * in_w; | |||
| int in_batch_offset = b * ic4 * C4NUM * in_h * in_w; | |||
| int out_batch_offset = b * out_channel * out_h * out_w; | |||
| int gemm_in_batch_offset = b * packed_input_size; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| @@ -263,12 +263,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||
| int output_count = out_w_block * out_h_block; | |||
| int output_tile_count = UP_DIV(output_count, TILE_NUM); | |||
| int out_channel = conv_param->output_channel_; | |||
| int out_batch = conv_param->output_batch_; | |||
| int oc4 = UP_DIV(out_channel, C4NUM); | |||
| int input_unit_square = input_unit * input_unit; | |||
| size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float); | |||
| bool is_relu = conv_param->is_relu_; | |||
| bool is_relu6 = conv_param->is_relu6_; | |||
| float *trans_input = buffer_list[0]; | |||
| float *gemm_out = buffer_list[1]; | |||
| @@ -280,11 +277,13 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||
| // step 1 : filter transform (pre-processed offline) | |||
| // step 2 : input transform (online) | |||
| for (int b = 0; b < in_batch; b++) { | |||
| int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | |||
| int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | |||
| int out_tile_index = thread_id * TILE_NUM; | |||
| int cal_num = output_count - thread_id * TILE_NUM; | |||
| cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num; | |||
| WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset, | |||
| WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, | |||
| tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | |||
| input_trans_func); | |||
| // step 3 : gemm | |||
| @@ -292,21 +291,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||
| input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0); | |||
| // step 4 : output transform | |||
| WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index, | |||
| out_w_block, conv_param, output_trans_func); | |||
| WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data, | |||
| cal_num, out_tile_index, out_w_block, conv_param, output_trans_func); | |||
| } | |||
| } | |||
| // get real output | |||
| UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel, | |||
| out_unit); | |||
| int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; | |||
| if (is_relu) { | |||
| ReluFp32(output_data, output_data, output_num); | |||
| } else if (is_relu6) { | |||
| Relu6Fp32(output_data, output_data, output_num); | |||
| } else { | |||
| // do nothing | |||
| } | |||
| } | |||
| void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, | |||
| @@ -360,8 +348,6 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||
| int output_count = out_w_block * out_h_block; | |||
| int output_tile_count = UP_DIV(output_count, TILE_NUM); | |||
| int input_unit_square = 4 * 4; | |||
| bool is_relu = conv_param->is_relu_; | |||
| bool is_relu6 = conv_param->is_relu6_; | |||
| float *tile_buffer = buffer_list[0]; | |||
| float *block_unit_buffer = buffer_list[1]; | |||
| float *tmp_dst_buffer = buffer_list[2]; | |||
| @@ -372,10 +358,13 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||
| int input_batch = conv_param->input_batch_; | |||
| for (int batch = 0; batch < input_batch; batch++) { | |||
| int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | |||
| int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | |||
| int start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | |||
| Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, | |||
| Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | |||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | |||
| out_w_block, conv_param); | |||
| @@ -383,17 +372,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||
| transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM, | |||
| oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0); | |||
| Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index, | |||
| real_cal_num, out_w_block, conv_param); | |||
| Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset, | |||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||
| } | |||
| PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); | |||
| } | |||
| int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; | |||
| if (is_relu) { | |||
| ReluFp32(output_data, output_data, output_num); | |||
| } else if (is_relu6) { | |||
| Relu6Fp32(output_data, output_data, output_num); | |||
| } else { | |||
| // do nothing | |||
| } | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| hiai_model_0909_kd_rot_ps_softmax.tflite | |||
| hiai_chinese_english_recognize_model_float32.tflite | |||
| #hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite | |||
| hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite | |||
| hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite | |||
| hiai_cn_recognize_modify_padv2.tflite | |||
| hiai_model_normalize_object_scene_ps_20200519.tflite | |||