From 0e5d30f9174e42deb2ac6f566ab1b6ac7af11def Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Wed, 29 Jul 2020 16:03:38 +0800 Subject: [PATCH] fix bug && optimize relu/relu6/winograd unpack func --- mindspore/lite/CMakeLists.txt | 3 + .../src/runtime/kernel/arm/CMakeLists.txt | 11 +- .../kernel/arm/base/convolution_base.cc | 71 +- .../kernel/arm/base/convolution_base.h | 6 +- .../kernel/arm/fp16/convolution_3x3_fp16.cc | 21 +- .../kernel/arm/fp16/convolution_fp16.cc | 9 +- .../kernel/arm/fp16/convolution_fp16.h | 4 +- .../runtime/kernel/arm/fp32/convolution.cc | 16 +- .../kernel/arm/int8/convolution_int8.cc | 32 +- .../runtime/kernel/arm/opclib/CMakeLists.txt | 15 +- .../runtime/kernel/arm/opclib/common_func.cc | 55 +- .../kernel/arm/opclib/conv_parameter.h | 2 +- .../kernel/arm/opclib/fp16/conv_fp16.cc | 2 +- .../kernel/arm/opclib/fp32/common_func.cc | 2 +- .../runtime/kernel/arm/opclib/fp32/conv.cc | 75 +- .../src/runtime/kernel/arm/opclib/fp32/conv.h | 6 +- .../runtime/kernel/arm/opclib/fp32/pooling.cc | 5 +- .../kernel/arm/opclib/int8/conv_int8.cc | 22 +- .../kernel/arm/opclib/opt_op_handler.c | 1 - .../src/runtime/kernel/arm/opclib/pack.cc | 33 +- .../lite/src/runtime/kernel/arm/opclib/pack.h | 48 - .../arm/opclib/quantization/fixed_point.h | 666 +------- .../kernel/arm/opclib/winograd_utils.cc | 1340 ++++++++++++++--- 23 files changed, 1370 insertions(+), 1075 deletions(-) diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index f35b5b60c1..b8191fed7b 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -113,6 +113,9 @@ if (BUILD_DEVICE) if (PLATFORM_ARM64) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") add_compile_definitions(ENABLE_ARM64) + if (ENABLE_FP16) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + endif () endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark) diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index 2f88b378a3..a2a39274c7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -1,15 +1,16 @@ -file(GLOB_RECURSE KERNEL_SRC +file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/opclib/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp32/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/opclib/int8/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/quantization/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc ) if (PLATFORM_ARM64) # assembly - file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s + file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.S) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) @@ -17,13 +18,15 @@ endif() if (PLATFORM_ARM32) # assembly - file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s) + file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.S + ) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) endif() if (ENABLE_FP16) - file(GLOB_RECURSE FP6_SRC + file(GLOB FP6_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc ) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 815cccfc8d..bb615d0ae7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -101,7 +101,6 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() { int ConvolutionBaseCPUKernel::Init() { auto input = this->inputs_.front(); auto output = this->outputs_.front(); - conv_param_->input_batch_ = input->Batch(); conv_param_->input_h_ = input->Height(); conv_param_->input_w_ = input->Width(); @@ -111,7 +110,6 @@ int ConvolutionBaseCPUKernel::Init() { conv_param_->output_w_ = output->Width(); conv_param_->output_channel_ = output->Channel(); conv_param_->thread_num_ = ctx_->threadNum; - return RET_OK; } @@ -221,9 +219,24 @@ void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *con } } -kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { +bool CheckSupportFP16() { + bool support_fp16 = false; +#ifdef ENABLE_ARM64 + void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; + if (optimize_op_handler != nullptr) { + support_fp16 = true; + MS_LOG(INFO) << "Support FP16."; + } else { + support_fp16 = false; + MS_LOG(INFO) << "Your machine doesn't support fp16, return back to float32 kernel."; + } +#endif + return support_fp16; +} + +kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { auto conv_param = reinterpret_cast(opParameter); int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; @@ -240,43 +253,34 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto conv_param = reinterpret_cast(opParameter); - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; - - if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } else { - auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; +#endif + } + auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); return kernel; } } -#endif kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, @@ -308,17 +312,10 @@ kernel::LiteKernel *CpuConvKernelCreator(const std::vector #endif #include "src/lite_kernel.h" - #include "include/context.h" - #include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" using mindspore::lite::Context; using mindspore::schema::PadMode; @@ -40,7 +39,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel { public: ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { opParameter->thread_num_ = ctx->threadNum; conv_param_ = reinterpret_cast(opParameter); } @@ -63,6 +62,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel { LayoutConvertor convert_func_; }; void ComputeQuantOutRange(ConvParameter *conv_param); +bool CheckSupportFP16(); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc index 600337cfb1..20bbb7ecd1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -49,6 +49,8 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara int Convolution3x3FP16CPUKernel::InitWeightBias() { auto input_channel = conv_param_->input_channel_; int output_channel = conv_param_->output_channel_; + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; int iC4 = UP_DIV(input_channel, C4NUM); int oC8 = UP_DIV(output_channel, C8NUM); // init weight @@ -60,8 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() { } memset(transformed_filter_addr_, 0, transformed_size); float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); - size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); - fp16_weight_ = malloc(fp16_weight_size); + size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = reinterpret_cast(malloc(fp16_weight_size)); if (fp16_weight_ == nullptr) { MS_LOG(ERROR) << "malloc fp16_weight_ failed."; return RET_ERROR; @@ -74,16 +76,17 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() { // init bias size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t); - bias_data_ = reinterpret_cast(malloc(new_bias_size)); + bias_data_ = malloc(new_bias_size); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "malloc bias_data_ failed."; return RET_ERROR; } memset(bias_data_, 0, new_bias_size); + auto fp16_bias_data = reinterpret_cast(bias_data_); if (inputs_.size() == kInputSize2) { auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); - for (int i = 0; i < out_channel; ++i) { - bias_data_[i] = (float16_t)ori_bias_addr[i]; + for (int i = 0; i < output_channel; ++i) { + fp16_bias_data[i] = (float16_t)ori_bias_addr[i]; } } else { MS_ASSERT(inputs_.size() == kInputSize1); @@ -129,16 +132,15 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { } memset(tmp_out_, 0, tmp_out_size); - size_t fp16_input_size = - in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); - fp16_input_ = malloc(fp16_input_size); + size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ * + conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = reinterpret_cast(malloc(fp16_input_size)); if (fp16_input_ == nullptr) { MS_LOG(ERROR) << "malloc fp16_input_ failed."; return RET_ERROR; } memset(fp16_input_, 0, fp16_input_size); - // init nhwc4 input size_t nhwc4_input_size = iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); @@ -249,4 +251,3 @@ int Convolution3x3FP16CPUKernel::Run() { return RET_OK; } } // namespace mindspore::kernel - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index f00a5835e5..c73d85a9d1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -42,7 +42,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { // init weight float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); - fp16_weight_ = malloc(fp16_weight_size); + fp16_weight_ = reinterpret_cast(malloc(fp16_weight_size)); if (fp16_weight_ == nullptr) { MS_LOG(ERROR) << "malloc fp16_weight_ failed."; return RET_ERROR; @@ -60,16 +60,17 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { PackWeightFp16(fp16_weight_, conv_param_, packed_weight_); // init bias - bias_data_ = reinterpret_cast(malloc(oc8 * C8NUM * sizeof(float16_t))); + bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "malloc bias_data_ failed."; return RET_ERROR; } memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t)); + auto fp16_bias_data = reinterpret_cast(bias_data_); if (inputs_.size() == kInputSize2) { auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); for (int i = 0; i < out_channel; ++i) { - bias_data_[i] = (float16_t)ori_bias[i]; + fp16_bias_data[i] = (float16_t)ori_bias[i]; } } else { MS_ASSERT(inputs_.size() == kInputSize1); @@ -101,7 +102,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { size_t fp16_input_size = in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); - fp16_input_ = malloc(fp16_input_size); + fp16_input_ = reinterpret_cast(malloc(fp16_input_size)); if (fp16_input_ == nullptr) { MS_LOG(ERROR) << "malloc fp16_input_ failed."; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h index bf7540f6fc..ecd785aa1e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -20,9 +20,7 @@ #include #include #include "src/lite_kernel.h" - #include "src/runtime/kernel/arm/base/convolution_base.h" -#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" namespace mindspore::kernel { typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, @@ -33,7 +31,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { public: ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} ~ConvolutionFP16CPUKernel() override { if (fp16_input_ != nullptr) { free(fp16_input_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index fb6f68a58e..615c9c9a88 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -33,10 +33,17 @@ int ConvolutionCPUKernel::InitWeightBias() { int kernel_w = conv_param_->kernel_w_; int in_channel = conv_param_->input_channel_; int out_channel = conv_param_->output_channel_; - int oc8 = UP_DIV(out_channel, C8NUM); int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + int oc_block, oc_block_num; +#ifdef ENABLE_ARM32 + oc_block = C4NUM; + oc_block_num = UP_DIV(out_channel, C4NUM); +#else + oc_block = C8NUM; + oc_block_num = UP_DIV(out_channel, C8NUM); +#endif + int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane; // init weight auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); @@ -49,12 +56,12 @@ int ConvolutionCPUKernel::InitWeightBias() { PackWeightFp32(origin_weight, conv_param_, packed_weight_); // init bias - bias_data_ = reinterpret_cast(malloc(oc8 * C8NUM * sizeof(float))); + bias_data_ = reinterpret_cast(malloc(oc_block_num * oc_block * sizeof(float))); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "malloc bias failed."; return RET_ERROR; } - memset(bias_data_, 0, oc8 * C8NUM * sizeof(float)); + memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float)); if (inputs_.size() == kInputSize2) { auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); memcpy(bias_data_, ori_bias, out_channel * sizeof(float)); @@ -198,4 +205,3 @@ int ConvolutionCPUKernel::Run() { return RET_OK; } } // namespace mindspore::kernel - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index cbd652225c..80d1284871 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -55,6 +55,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() { support_optimize_ = false; } #endif + conv_param_->tile_num_ = tile_num_; } int ConvolutionInt8CPUKernel::InitWeightBias() { @@ -78,7 +79,7 @@ int ConvolutionInt8CPUKernel::InitWeightBias() { return RET_ERROR; } memset(packed_weight_, 0, pack_weight_size); - int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + auto *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); for (int i = 0; i < out_channel; i++) weight_sum[i] = 0; PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum); @@ -105,15 +106,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() { } int ConvolutionInt8CPUKernel::InitTmpBuffer() { - int tile_n = 4; int output_count = conv_param_->output_h_ * conv_param_->output_w_; - int output_tile_count = UP_DIV(output_count, tile_n); + int output_tile_count = UP_DIV(output_count, tile_num_); int in_channel = conv_param_->input_channel_; int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; int plane_c4 = UP_DIV(kernel_plane, C4NUM); int unit_size = plane_c4 * C4NUM * ic4 * C4NUM; - int packed_input_size = output_tile_count * tile_n * unit_size; + int packed_input_size = output_tile_count * tile_num_ * unit_size; packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); if (packed_input_ == nullptr) { @@ -122,14 +122,14 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { } memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); - input_sum_ = reinterpret_cast(malloc(tile_n * thread_count_ * sizeof(int32_t))); + input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); if (input_sum_ == nullptr) { MS_LOG(ERROR) << "malloc input_sum_ failed."; return RET_ERROR; } - memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t)); + memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); - size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t); + size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); if (tmp_dst_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_dst_ failed."; @@ -137,7 +137,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() { } memset(tmp_dst_, 0, tmp_dst_size); - tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_n * conv_param_->output_channel_)); + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); if (tmp_out_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_out_ failed."; return RET_ERROR; @@ -173,7 +173,7 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { return RET_ERROR; } memset(packed_weight_, filter_zp, pack_weight_size); - int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + auto *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane; PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum); @@ -200,15 +200,13 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { } int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { - // todo - int tile_n = 24; int output_count = conv_param_->output_h_ * conv_param_->output_w_; - int output_tile_count = UP_DIV(output_count, tile_n); + int output_tile_count = UP_DIV(output_count, tile_num_); int in_channel = conv_param_->input_channel_; int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; int unit_size = kernel_plane * ic4 * C4NUM; - int packed_input_size = output_tile_count * tile_n * unit_size; + int packed_input_size = output_tile_count * tile_num_ * unit_size; packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); if (packed_input_ == nullptr) { @@ -217,14 +215,14 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { } memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); - input_sum_ = reinterpret_cast(malloc(tile_n * thread_count_ * sizeof(int32_t))); + input_sum_ = reinterpret_cast(malloc(tile_num_ * thread_count_ * sizeof(int32_t))); if (input_sum_ == nullptr) { MS_LOG(ERROR) << "malloc input_sum_ failed."; return RET_ERROR; } - memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t)); + memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t)); - size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t); + size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t); tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); if (tmp_dst_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_dst_ failed."; @@ -232,7 +230,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { } memset(tmp_dst_, 0, tmp_dst_size); - tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_n * conv_param_->output_channel_)); + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_)); if (tmp_out_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_out_ failed."; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt index 9895fd458f..8421a6200b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt @@ -5,19 +5,22 @@ set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../) include_directories(OPTIMIZED_OP_DIR) ########################### optimized files ########################### -set(FP16_ASSEMBLY -# ${OPTIMIZED_OP_DIR}/assembly/arm64/IndirectGemmFp16_16x8.s +file(GLOB OPTIMIZED_ASSEMBLY + ${OPTIMIZED_OP_DIR}/assembly/opt/*.s + ${OPTIMIZED_OP_DIR}/assembly/opt/*.S ) -file(GLOB_RECURSE OPTIMIZED_INT8_ASSEMBLY - ${OPTIMIZED_OP_DIR}/assembly/opt/*.S + +file(GLOB FP16_SRC + # ${OPTIMIZED_OP_DIR}/fp16/*.cc + # ${OPTIMIZED_OP_DIR}/../fp16/*.cc ) ########################### share library build ######################## set(OPTIMIZED_OPS "opt_op_handler.c") -set_property(SOURCE ${OPTIMIZED_INT8_ASSEMBLY} PROPERTY LANGUAGE C) -list(APPEND OPTIMIZED_OPS ${OPTIMIZED_INT8_ASSEMBLY} ${FP16_ASSEMBLY}) +set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C) +list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC}) if (PLATFORM_ARM64) string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc index 810cf6f764..ebc32b780b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc @@ -17,7 +17,7 @@ #include "src/runtime/kernel/arm/opclib/common_func.h" #include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" -#ifndef ENABLE_ARM +#ifndef __aarch64__ void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, int output_channel, size_t offset, size_t relu, size_t relu6) { for (int i = 0; i < TILE_NUM; i++) { @@ -108,24 +108,49 @@ int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } void ReluFp32(float *data, int ele_num) { - for (int i = 0; i < ele_num; i++) { - if (data[i] < 0) { - data[i] = 0; - } else { - // do nothing - } + int four_block = UP_DIV(ele_num, C4NUM); + for (int i = 0; i < four_block - 1; i++) { + int index = i * C4NUM; +#ifdef ENABLE_NEON + float32x4_t relu_data = vld1q_f32(data + index); + float32x4_t zero_data = vdupq_n_f32(0); + relu_data = vmaxq_f32(relu_data, zero_data); +#else + data[index] = data[index] < 0 ? 0 : data[index]; + data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; + data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; + data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; +#endif + } + for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) { + data[j] = data[j] < 0 ? 0 : data[j]; } } void Relu6Fp32(float *data, int ele_num) { - for (int i = 0; i < ele_num; i++) { - if (data[i] < 0) { - data[i] = 0; - } else if (data[i] > 6) { - data[i] = 6; - } else { - // do nothing - } + int four_block = UP_DIV(ele_num, C4NUM); + for (int i = 0; i < four_block - 1; i++) { + int index = i * C4NUM; +#ifdef ENABLE_NEON + float32x4_t relu6_data = vld1q_f32(data + index); + float32x4_t zero_data = vdupq_n_f32(0); + float32x4_t six_data = vdupq_n_f32(6); + relu6_data = vmaxq_f32(relu6_data, zero_data); + relu6_data = vminq_f32(relu6_data, six_data); +#else + data[index] = data[index] < 0 ? 0 : data[index]; + data[index] = data[index] > 6 ? 6 : data[index]; + data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; + data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1]; + data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; + data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2]; + data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; + data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3]; +#endif + } + for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) { + data[j] = data[j] < 0 ? 0 : data[j]; + data[j] = data[j] > 6 ? 6 : data[j]; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h index 8d5ad1b708..9ac4c95ce0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h @@ -39,7 +39,7 @@ struct ConvParameter { int pad_l_; int pad_r_; int group_; - int n_dim_; + int tile_num_; int input_batch_; int input_h_; int input_w_; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc index 50c57f3132..a112cd1553 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc @@ -141,7 +141,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ int start_index = thread_id * tile_n; int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; float16_t *gemm_input = - (float *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); + (float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); int out_offset = thread_id * tile_n * out_channel + out_batch_offset; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc index ce5f1a06fd..ffa93e0071 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc @@ -16,7 +16,7 @@ #include "src/runtime/kernel/arm/opclib/fp32/common_func.h" -#ifndef ENABLE_ARM +#ifndef __aarch64__ void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, size_t row, size_t col) { for (int r = 0; r < row; r++) { diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc index cc121dc4b5..66fe33c862 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc @@ -31,13 +31,12 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; int thread_count = conv_param->thread_num_; - int tile_n = 8; int output_count = out_h * out_w; - int output_tile_count = UP_DIV(output_count, tile_n); + int output_tile_count = UP_DIV(output_count, TILE_NUM); int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; int unit_size = kernel_plane * ic4 * C4NUM; - int packed_input_size = output_tile_count * tile_n * unit_size; + int packed_input_size = output_tile_count * TILE_NUM * unit_size; // we accumulate 4 channels per time for input blocks int conv_depth = kernel_h * kernel_w; @@ -50,13 +49,13 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons int out_batch_offset = b * out_channel * out_h * out_w; int gemm_in_batch_offset = b * packed_input_size; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { - int start_index = thread_id * tile_n; - int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; - float *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset; Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); - int out_offset = thread_id * tile_n * out_channel + out_batch_offset; - if (real_cal_num == tile_n) { + int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset; + if (real_cal_num == TILE_NUM) { float *gemm_output = output_data + out_offset; IndirectGemmFp32_8x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_); @@ -121,22 +120,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ } } // get real output - for (int batch = 0; batch < out_batch; batch++) { - int batch_size = batch * out_channel * conv_param->output_h_ * conv_param->output_w_; - for (int h = 0; h < conv_param->output_h_; h++) { - for (int w = 0; w < conv_param->output_w_; w++) { - for (int c = 0; c < out_channel; c++) { - int oc4_block = c / C4NUM; - int oc4_res = c % C4NUM; - int src_offset = oc4_block * C4NUM * out_w_block * out_h_block * out_unit * out_unit + - C4NUM * (h * out_w_block * out_unit + w) + oc4_res; - int dst_offset = (h * conv_param->output_w_ + w) * out_channel + c; - (output_data + dst_offset)[0] = (tmp_out_data + src_offset)[0]; - } - } - } - } - + UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel, + out_unit); int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; if (is_relu) { ReluFp32(output_data, output_num); @@ -147,6 +132,45 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ } } +void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, + int output_unit) { + int out_h_block_num = UP_DIV(height, output_unit); + int out_w_block_num = UP_DIV(width, output_unit); + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_batch_offset = b * c4 * C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; + int dst_batch_offset = b * height * width * channel; + for (int h = 0; h < height; h++) { + int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit); + int dst_h_offset = dst_batch_offset + h * width * channel; + for (int w = 0; w < width; w++) { + int src_w_offset = src_h_offset + w * C4NUM; + int dst_w_offset = dst_h_offset + w * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c4_offset = src_w_offset + c * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; + int dst_c4_offset = dst_w_offset + c * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset)); +#else + dst[dst_c4_offset] = src[src_c4_offset]; + dst[dst_c4_offset + 1] = src[src_c4_offset + 1]; + dst[dst_c4_offset + 2] = src[src_c4_offset + 2]; + dst[dst_c4_offset + 3] = src[src_c4_offset + 3]; +#endif + } + int c_res = channel - (c4 - 1) * C4NUM; + int src_c_res_offset = (c4 - 1) * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit; + int dst_c_res_offset = (c4 - 1) * C4NUM; + for (int c = 0; c < c_res; c++) { + int src_c4_res_offset = src_w_offset + src_c_res_offset + c; + int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c; + dst[dst_c4_res_offset] = src[src_c4_res_offset]; + } + } + } + } +} + // fp32 conv3x3 void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param) { @@ -182,7 +206,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat } PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); } - int output_num = oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; + int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; if (is_relu) { ReluFp32(output_data, output_num); } else if (is_relu6) { @@ -191,4 +215,3 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat // do nothing } } - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h index 5c1d096532..07b316587f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h @@ -42,10 +42,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func); +void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit); + // fp32 conv3x3 void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, - TmpBufferAddress *buffer_list, int task_id, - ConvParameter *conv_param); + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc index 084e7d42d5..f2d3c11d03 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc @@ -81,8 +81,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo } // win_w loop } // win_h loop #ifdef ENABLE_NEON - float32x4_t dup_count = vdupq_n_f32(real_count); - vst1q_f32(output_ptr + out_channel_offset, vdivq_f32(tmp_avg, dup_count)); + float reverse_count = 1 / real_count; + float32x4_t dup_count = vdupq_n_f32(reverse_count); + vst1q_f32(output_ptr + out_channel_offset, vmulq_f32(tmp_avg, dup_count)); #else *(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count; *(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc index e3573498d0..fe7d9f0b9c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc @@ -24,11 +24,6 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t * size_t oc4, size_t offset); #ifdef ENABLE_ARM64 -// void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t -// ksize, -// size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, -// size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, size_t -// shift_before, size_t shift_after); void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, @@ -54,8 +49,9 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in #ifdef __aarch64__ IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); + // todo arm32 #else - int tile_num = 4; + int tile_num = conv_param->tile_num_; int plane_c4 = UP_DIV(kernel_plane, C4NUM); for (int oc = 0; oc < output_channel; oc++) { int oc4_block = oc / C4NUM; @@ -109,7 +105,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); #endif } else { - int tile_num = 24; + int tile_num = conv_param->tile_num_; for (int oc = 0; oc < output_channel; oc++) { int oc4_block = oc / C4NUM; int oc4_res = oc % C4NUM; @@ -202,7 +198,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c int out_channel = conv_param->output_channel_; int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; - int tile_n = 4; + int tile_n = conv_param->tile_num_; int thread_count = conv_param->thread_num_; int output_count = out_h * out_w; int output_tile_count = UP_DIV(output_count, tile_n); @@ -255,9 +251,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight int out_w = conv_param->output_w_; int out_channel = conv_param->output_channel_; int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; - - // todo - int tile_n = 24; + int tile_n = conv_param->tile_num_; int thread_count = conv_param->thread_num_; int output_count = out_h * out_w; int output_tile_count = UP_DIV(output_count, tile_n); @@ -302,7 +296,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, int task_id, ConvParameter *conv_param) { - // todo int thread_count = conv_param->thread_num_; int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); int output_batch = conv_param->output_batch_; @@ -331,8 +324,5 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi } // get real output - for (int batch = 0; batch < output_batch; batch++) { - // int batch_size = batch * output_channel * output_h * output_w; - C4UnpackToHwcInt8(tmp_out, output_data, output_channel, output_h, output_w); - } + PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel); } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c b/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c index b6dec4f2e8..82591d4aef 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c @@ -16,7 +16,6 @@ #include -// todo extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize, size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc index 8ed9691970..7c89e05c19 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc @@ -345,6 +345,7 @@ void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, i void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) { // original weight format : ohwi + // todo pack weight for arm32 platform int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; int in_channel = conv_param->input_channel_; @@ -352,7 +353,7 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed int oc8 = UP_DIV(out_channel, C8NUM); int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + int pack_weight_size = oc8 * C8NUM * ic4 * C4NUM * kernel_plane; int unit_size = C8NUM * C4NUM; int block_size = pack_weight_size / oc8; @@ -565,7 +566,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, int32_t *input_sum, ConvParameter *conv_param) { // input format : nhwc - int tile_num = 4; + int tile_num = conv_param->tile_num_; int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; @@ -624,7 +625,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, int32_t *input_sum, ConvParameter *conv_param) { // input format : nhwc - int tile_num = 24; + int tile_num = conv_param->tile_num_; int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; @@ -980,15 +981,23 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int for (int b = 0; b < batch; b++) { int src_offset = b * plane * c4 * C4NUM; int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k * channel; - ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0]; } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h index 95e711851e..66438103eb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h @@ -122,52 +122,4 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); -inline void UnpackHwcToChwFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { - int cur = 0; - for (int i = 0; i < channel; i++) { - auto plane = i / BLOCK; - auto offset = i % BLOCK; - auto src_plane = plane * h * w * BLOCK + src_ptr; - for (int j = 0; j < h * w; j++) { - dst_ptr[cur++] = src_plane[j * BLOCK + offset]; - } - } -} - -inline void C8UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { - int cur = 0; - for (int j = 0; j < h * w; j++) { - for (int i = 0; i < channel; i++) { - auto plane = i / 8; - auto offset = i % 8; - auto src_plane = plane * h * w * 8 + src_ptr; - dst_ptr[cur++] = src_plane[j * 8 + offset]; - } - } -} - -inline void C4UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { - int cur = 0; - for (int j = 0; j < h * w; j++) { - for (int i = 0; i < channel; i++) { - auto plane = i / 4; - auto offset = i % 4; - auto src_plane = plane * h * w * 4 + src_ptr; - dst_ptr[cur++] = src_plane[j * 4 + offset]; - } - } -} - -inline void C4UnpackToHwcInt8(int8_t *src_ptr, int8_t *dst_ptr, int channel, int h, int w) { - int cur = 0; - for (int j = 0; j < h * w; j++) { - for (int i = 0; i < channel; i++) { - auto plane = i / 4; - auto offset = i % 4; - auto src_plane = plane * h * w * 4 + src_ptr; - dst_ptr[cur++] = src_plane[j * 4 + offset]; - } - } -} - #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h index 3249d2e47c..9477b01fef 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h @@ -17,659 +17,43 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_ -#include -#include -#include -#include -#include +#include +#include "include/infer_log.h" #ifdef ENABLE_NEON #include #endif -// Part 1: Low-level integer-arithmetic primitives. -// The implementations here are generic implementations valid for -// scalar types (e.g. std::int32_t). Architecture-specific SIMD types -// (e.g. NEON int32x4_t) may be supported by providing -// specializations for them in separate files. -// -// The purpose of these primitives is two-fold: -// - They will be used to implement higher-level fixed-point -// abstractions, namely the FixedPoint class and its arithmetic -// operators. -// - They will be directly used to implement some more involved -// fixed-point computations, e.g. the fixed-point implementation -// of math functions such as tanh. - -// Some compile-time traits around raw types to handle SIMD aspects: -// number of lanes, underlying scalar type. -template -struct FixedPointRawTypeTraits {}; - -template <> -struct FixedPointRawTypeTraits { - typedef std::int32_t ScalarRawType; - static constexpr int kLanes = 1; -}; - -template <> -struct FixedPointRawTypeTraits { - typedef std::int16_t ScalarRawType; - static constexpr int kLanes = 1; -}; - -// Returns a SIMD value duplicating a scalar value across all lanes. -template -tRawType Dup(typename FixedPointRawTypeTraits::ScalarRawType x) { - return x; -} - -// Plain bit-wise AND -template -tIntegerType BitAnd(tIntegerType a, tIntegerType b) { - return a & b; -} - -// Plain bit-wise OR -template -tIntegerType BitOr(tIntegerType a, tIntegerType b) { - return a | b; -} - -// Plain bit-wise XOR -template -tIntegerType BitXor(tIntegerType a, tIntegerType b) { - return a ^ b; -} - -// Plain bit-wise NOT -template -tIntegerType BitNot(tIntegerType a) { - return ~a; -} - -// Integer addition. Not saturating. Overflow is undefined behavior. -template -tIntegerType Add(tIntegerType a, tIntegerType b) { - return a + b; -} - -// Integer multiplication. Not saturating. Overflow is undefined behavior. -template -tIntegerType Mul(tIntegerType a, tIntegerType b) { - return a * b; -} - -// Integer subtraction. Not saturating. Overflow is undefined behavior. -template -tIntegerType Sub(tIntegerType a, tIntegerType b) { - return a - b; -} - -// Integer unary negative. Not saturating. Overflow is undefined behavior. -template -tIntegerType Neg(tIntegerType a) { - return -a; -} - -// Integer arithmetic left-shift, equivalent to multiplying with a power of two. -// Negative values are OK. In case of overflow, no Undefined -// Behavior, but the results are implementation-defined (in practice, -// they currently are saturated, but we make no commitment to that). The idea -// is that the caller will want to implement the overflowing cases with -// saturation with compare-and-mask, so we don't care about the results -// in the overflow case, we just want to avoid undefined behavior. -// -// tIntegerType may be int32 or any narrower signed type. -template -tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) { - const std::int64_t wide_a = (std::int64_t)(a); - const std::int64_t wide_shifted = wide_a * (1 << offset); - const auto min = std::numeric_limits::min(); - const auto max = std::numeric_limits::max(); - return wide_shifted < min ? min : wide_shifted > max ? max : (tIntegerType)(wide_shifted); -} - -// Integer arithmetic right-shift. Not rounding. -// Relying on implementation-defined, but in-practice-consistent, -// C++ compiler behavior. -template -tIntegerType ShiftRight(tIntegerType a, int offset) { - return a >> offset; -} - -// Each bit of the result is set to the corresponding bit of either then_val or -// else_val depending on whether the corresponding bit of if_mask is set. -// Equivalent to the VBSL instruction in ARM NEON. -template -tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) { - return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); -} - -// For each input scalar, the corresponding bits of the result are set if the -// input scalar is non-zero. -template -tIntegerType MaskIfNonZero(tIntegerType a) { - static constexpr tIntegerType zero = 0; - return a ? BitNot(zero) : zero; -} - -// For each input scalar, the corresponding bits of the result are set if the -// input scalar is zero. -template -tIntegerType MaskIfZero(tIntegerType a) { - return MaskIfNonZero(!a); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars are equal. -template -tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a == b); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars are not equal. -template -tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a != b); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars a, b satisfy a > b. -template -tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a > b); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars a, b satisfy a >= b. -template -tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a >= b); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars a, b satisfy a < b. -template -tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a < b); -} - -// For each pair of input scalars, the corresponding bits of the result are -// set if the input scalars a, b satisfy a <= b. -template -tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { - return MaskIfNonZero(a <= b); -} - -// Returns true if all of the input scalars are nonzero. -// This function may currently assume that each of the input scalars has either -// all or none of its bits set. Otherwise, its behavior is currently undefined. -template -bool All(tIntegerType a) { - return a; -} - -// Returns true if any of the input scalars are nonzero. -// This function may currently assume that each of the input scalars has either -// all or none of its bits set. Otherwise, its behavior is currently undefined. -template -bool Any(tIntegerType a) { - return a; -} - -// Returns (a+b)/2, rounded to the nearest integer. -// Equivalent to VRHADD in the ARM NEON instruction set. -template -IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - (void)b; - return a; -} - -template <> -inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { - std::int64_t a64 = a; - std::int64_t b64 = b; - std::int64_t sum = a64 + b64; - std::int64_t sign = sum >= 0 ? 1 : -1; - return (std::int32_t)((sum + sign) / 2); -} - -template <> -inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { - std::int32_t a32 = a; - std::int32_t b32 = b; - std::int32_t sum = a32 + b32; - std::int32_t sign = sum >= 0 ? 1 : -1; - return (std::int16_t)((sum + sign) / 2); -} - -template -IntegerType SaturatingAdd(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - (void)b; - return a; -} - -// So far this is only needed for int16. -template <> -inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { - std::int32_t a32 = a; - std::int32_t b32 = b; - std::int32_t sum = a32 + b32; - return (std::int16_t)(std::min((std::int32_t)(32767), std::max((std::int32_t)(-32768), sum))); -} - -template <> -inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { - std::int16_t a16 = a; - std::int16_t b16 = b; - std::int16_t sum = a16 + b16; - return (std::int8_t)(std::min((int16_t)(std::numeric_limits::max()), - std::max((int16_t)(std::numeric_limits::min()), sum))); -} - -// Returns a+b, saturating if the integers are 16bit or narrower, -// otherwise just a plain addition. -template -struct AddSaturatingIf16BitImpl { - static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } -}; -template -struct AddSaturatingIf16BitImpl { - static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); } -}; -template -IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { - using ScalarType = typename FixedPointRawTypeTraits::ScalarRawType; - return AddSaturatingIf16BitImpl::Run(a, b); -} - -// Returns the integer that represents the product of two fixed-point -// numbers, interpreting all integers as fixed-point values in the -// interval [-1, 1), rounding to the nearest value, and saturating -// -1 * -1 to the maximum value (since 1 is not in the half-open -// interval [-1, 1)). -// -// [The explanation below specializes to std::int32_t for example purpose.] -// -// The mapping between IntegerType and the interval [-1, 1) is unique and -// implied by IntegerType, which is assumed to be signed. For example, -// for IntegerType==std::int32_t, the mapping is -// real_value = integer_value / 2^31. -// So in this case, and leaving aside rounding and saturating, this -// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to -// (a * b) / 2^31. -// -// The 'doubling' part in the name of this function comes from the fact that -// this operation is very close to a "multiply-high" operation, keeping only -// the top half bits, except that that would be effectively computing -// (a * b) / 2^32, -// so here we are computing 2x that, since -// 1/2^31 = 2 * 1/2^32. -// The idea is to use all of the available 32 bits in the destination int32 -// value. -// -// [End of the explanation specializing to int32.] -// -// This is equivalent to the VQRDMULH instruction in ARM NEON. -template -IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { - static_assert(std::is_same::value, "unimplemented"); - (void)b; - return a; -} - -// This function implements the same computation as the ARMv7 NEON VQRDMULH -// instruction. -template <> -inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) { - bool overflow = a == b && a == std::numeric_limits::min(); - std::int64_t a_64(a); - std::int64_t b_64(b); - std::int64_t ab_64 = a_64 * b_64; - std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); - std::int32_t ab_x2_high32 = (std::int32_t)((ab_64 + nudge) / (1ll << 31)); - return overflow ? std::numeric_limits::max() : ab_x2_high32; -} - -template <> -inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) { - bool overflow = a == b && a == std::numeric_limits::min(); - std::int32_t a_32(a); - std::int32_t b_32(b); - std::int32_t ab_32 = a_32 * b_32; - std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); - std::int16_t ab_x2_high16 = (std::int16_t)((ab_32 + nudge) / (1 << 15)); - return overflow ? std::numeric_limits::max() : ab_x2_high16; +// returns the high-32 bits of a * b with rounding +// assume that a and b is divided by 2^31, who fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31= (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +inline int SaturatingRoundingDoublingHighMul(int a, int b) { + if (a == INT_MIN && b == INT_MIN) { + return INT_MAX; + } + int64_t ab = ((int64_t)a) * ((int64_t)b); + int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30)); + // do not apply right shift to potential negetive values + int ab_mantissa = (int) ((ab + rounding) / (1ll << 31)); + return ab_mantissa; } -// Correctly-rounded-to-nearest division by a power-of-two. -// Also known as a rounding arithmetic right shift. -template -inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { - assert(exponent >= 0); - assert(exponent <= 31); - const IntegerType mask = Dup((1ll << exponent) - 1); - const IntegerType zero = Dup(0); - const IntegerType one = Dup(1); - const IntegerType remainder = BitAnd(x, mask); - const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); - return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one)); +// division by a 2^exponent with rounding +// or arithmetic right shift with rouding +inline int RoundingDivideByPOT(int x, int exponent) { + MS_ASSERT(exponent >= 0); + MS_ASSERT(exponent <= 31); + const int mask = (1ll << exponent) - 1; + const int remainder = x & mask; + const int threshold = (mask >> 1) + (x < 0 ? 1 : 0); + return (x >> exponent) + (remainder > threshold ? 1 : 0); } inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); } -// Returns the product of a run-time integer value by a compile-time power -// of two, with either a positive exponent (equivalent to an arithmetic -// left shift, saturating) or a negative exponent (equivalent to an arithmetic -// right shift, rounding to nearest). -template 0 ? 1 : Exponent < 0 ? -1 : 0)> -struct ImplSaturatingRoundingMultiplyByPOT {}; - -template -struct ImplSaturatingRoundingMultiplyByPOT { - static IntegerType eval(IntegerType x) { return x; } -}; - -template -struct ImplSaturatingRoundingMultiplyByPOT { - static IntegerType eval(IntegerType x) { - using ScalarIntegerType = typename FixedPointRawTypeTraits::ScalarRawType; - const IntegerType min = Dup(std::numeric_limits::min()); - const IntegerType max = Dup(std::numeric_limits::max()); - const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - - const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); - const IntegerType positive_mask = MaskIfGreaterThan(x, Dup(threshold)); - const IntegerType negative_mask = MaskIfLessThan(x, Dup(-threshold)); - - IntegerType result = ShiftLeft(x, Exponent); - result = SelectUsingMask(positive_mask, max, result); - result = SelectUsingMask(negative_mask, min, result); - return result; - } -}; - -template -struct ImplSaturatingRoundingMultiplyByPOT { - static IntegerType eval(IntegerType x) { return RoundingDivideByPOT(x, -Exponent); } -}; - -template -IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { - return ImplSaturatingRoundingMultiplyByPOT::eval(x); -} - -// Part 2: the FixedPoint class. - -// A FixedPoint object represents a fixed-point value stored in the underlying -// integer type tRawType, if tRawType is a plain scalar integer type. -// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which -// case a FixedPoint object represents a corresponding SIMD vector of fixed -// point values. -// -// tIntegerBits describes the range of the fixed-point format: if -// tIntegerBits == m then the range of representable values is the half-open -// interval [-2^m; 2^m) where the open boundary on the right side means that -// 2^m is not representable (how close the maximum representable value is to -// it, depends on bit-depth of tRawType). -// -// In "Q format notation", -// https://en.wikipedia.org/wiki/Q_(number_format) -// we are describing the format -// Qm.n -// where -// m = tIntegerBits -// and -// n = NumberOfBits(tRawType) - (m + 1) -// Note that the (m + 1) in the above line is because we adopt the convention -// that we count the integer bits exclusively of the sign bit; so (m + 1) is -// the total number of integer bits inclusive of the sign bit. -// -// Accordingly, the number of integral representable values in our range -// [-2^m ; 2^m) -// is equal to 2^(m+1). -template -class FixedPoint { - public: - typedef tRawType RawType; - - typedef FixedPointRawTypeTraits RawTypeTraits; - typedef typename RawTypeTraits::ScalarRawType ScalarRawType; - - static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); - static constexpr int kIntegerBits = tIntegerBits; - static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; - static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); - - typedef FixedPoint ScalarFixedPointType; - - static const ScalarRawType ScalarRawMin() { return std::numeric_limits::min(); } - - static const ScalarRawType ScalarRawMax() { return std::numeric_limits::max(); } - - static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); } - - static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); } - - static FixedPoint FromRaw(RawType x) { - FixedPoint retval; - retval.raw() = x; - return retval; - } - - static FixedPoint FromScalarRaw(ScalarRawType x) { - FixedPoint retval; - retval.raw() = Dup(x); - return retval; - } - - static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); } - - template - static FixedPoint ConstantPOT() { - static constexpr int kOffset = kFractionalBits + Exponent; - static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format"); - return FromScalarRaw(ScalarRawType(1) << kOffset); - } - - static FixedPoint Zero() { return FromScalarRaw(0); } - - static FixedPoint One() { - return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() - : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); - } - - static FixedPoint FromDouble(double x) { - const double min_bound = (double)(ScalarRawMin()); - const double max_bound = (double)(ScalarRawMax()); - return FromScalarRaw( - (ScalarRawType)(std::min(std::max(round(x * (double)(1ll << kFractionalBits)), min_bound), max_bound))); - } - - RawType raw() const { return i_; } - RawType &raw() { return i_; } - - private: - RawType i_; -}; - -// Part 3: implementation of arithmetic operators for the -// FixedPoint class, and a few related functions. - -// A FixedPoint multiplication is just a -// SaturatingRoundingDoublingHighMul operation on the underlying -// raw integer values. The IntegerBits simply add up, as is obvious -// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). -template -FixedPoint operator*(FixedPoint a, - FixedPoint b) { - FixedPoint c; - c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); - return c; -} - -// Tweaking IntegerBits gives exact multiplication by a power of two. -template -FixedPoint ExactMulByPot(FixedPoint a) { - FixedPoint c; - c.raw() = a.raw(); - return c; -} - -// If we want to leave IntegerBits fixed, then multiplication -// by a power of two has to be saturating/rounding, not exact anymore. -template -FixedPoint SaturatingRoundingMultiplyByPOT(FixedPoint a) { - return FixedPoint::FromRaw(SaturatingRoundingMultiplyByPOT(a.raw())); -} - -// Generic arithmetic operators. - -#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ - template \ - FixedPoint FuncName(FixedPoint a) { \ - return FixedPoint::FromRaw(ImplFuncName(a.raw())); \ - } - -#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ - template \ - FixedPoint FuncName(FixedPoint a, \ - FixedPoint b) { \ - return FixedPoint::FromRaw(ImplFuncName(a.raw(), b.raw())); \ - } - -MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) -MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) -MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) -MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) -MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) -MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) -MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) -MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) - -#undef MAKE_FIXEDPOINT_UNARY_FUNC -#undef MAKE_FIXEDPOINT_BINARY_FUNC - -#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ - template \ - tRawType FuncName(FixedPoint a) { \ - return FuncName(a.raw()); \ - } - -#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ - template \ - tRawType FuncName(FixedPoint a, FixedPoint b) { \ - return FuncName(a.raw(), b.raw()); \ - } - -MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) -MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) -MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) - -#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW -#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW - -template -FixedPoint SelectUsingMask(tRawType if_mask, FixedPoint then_val, - FixedPoint else_val) { - return FixedPoint::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); -} - -template -bool operator==(FixedPoint a, FixedPoint b) { - return All(MaskIfEqual(a.raw(), b.raw())); -} - -template -bool operator!=(FixedPoint a, FixedPoint b) { - return !(a == b); -} - -template -FixedPoint SaturatingAdd(FixedPoint a, - FixedPoint b) { - return FixedPoint::FromRaw(SaturatingAdd(a.raw(), b.raw())); -} - -template -FixedPoint AddSaturatingIf16Bit(FixedPoint a, - FixedPoint b) { - return FixedPoint::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw())); -} - -// Conversion to floating-point. -template -double ToDouble(FixedPoint x) { - static_assert(FixedPointRawTypeTraits::kLanes == 1, "not applicable to SIMD types"); - typedef FixedPoint F; - return x.raw() / (double)(1ll << F::kFractionalBits); -} - -// Rescale changes the number of IntegerBits and updates the underlying -// raw integer value accordingly. -template -FixedPoint Rescale(FixedPoint x) { - static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; - FixedPoint result; - result.raw() = SaturatingRoundingMultiplyByPOT(x.raw()); - return result; -} - -// CheckedFixedPointConstant allows to specify fixed-point constants -// initialized as real numbers, in a way that does not compile floating-point -// arithmetic in production code, yet still checks agreement with the -// floating-point expressions when asserts are enabled. -// -// The raw integer value provided is always a int32, encoding a 32-bit -// fixed-point value, regardless of the actual Scalar type. This allows -// writing generic code that applies just as well to the 32-bit and 16-bit -// cases. In the 16-bit case, the raw integer value is internally -// rounding-shifted by 16 bits to the right. -template -inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(std::int32_t int32_value) { - typedef typename FixedPointType::ScalarRawType ScalarRawType; - static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); - return (ScalarRawType)(RoundingDivideByPOT(int32_value, 32 - ScalarTypeBits)); -} - -// Implementation of exponential function. - -// Returns -tanh(x) for x < 0. -template -FixedPoint neg_tanh_on_negative_values(FixedPoint a) { - return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a))); -} - -// Returns tanh(x) for any x. -template -FixedPoint tanh(FixedPoint a) { - typedef FixedPoint InputF; - typedef FixedPoint ResultF; - tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); - tRawType mask_if_zero = MaskIfZero(a); - InputF n = SelectUsingMask(mask_if_negative, a, -a); - ResultF t = neg_tanh_on_negative_values(n); - return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t)); -} - -// Implementation of logistic function. - -// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. -template -FixedPoint logistic_on_positive_values(FixedPoint a) { - return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); -} - #ifdef ENABLE_NEON inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { const int32x4_t shift_vec = vdupq_n_s32(-exponent); diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc index 7e611a2728..7081c664af 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc @@ -1494,6 +1494,169 @@ void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step) { #ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)), src_data_70); + float32x4_t t11 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)), src_data_71); + float32x4_t t12 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)), src_data_72); + float32x4_t t13 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)), src_data_73); + float32x4_t t14 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)), src_data_74); + float32x4_t t15 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)), src_data_75); + float32x4_t t16 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)), src_data_76); + float32x4_t t17 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)), t17); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); #else for (int i = 0; i < C4NUM; i++) { float src_data_00 = src_data[i]; @@ -1588,33 +1751,6 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float float d27 = src_data_56 - src_data_66; float d28 = src_data_57 - src_data_67; - float d31 = src_data_10 + src_data_20; - float d32 = src_data_11 + src_data_21; - float d33 = src_data_12 + src_data_22; - float d34 = src_data_13 + src_data_23; - float d35 = src_data_14 + src_data_24; - float d36 = src_data_15 + src_data_25; - float d37 = src_data_16 + src_data_26; - float d38 = src_data_17 + src_data_27; - - float d41 = src_data_30 + src_data_40; - float d42 = src_data_31 + src_data_41; - float d43 = src_data_32 + src_data_42; - float d44 = src_data_33 + src_data_43; - float d45 = src_data_34 + src_data_44; - float d46 = src_data_35 + src_data_45; - float d47 = src_data_36 + src_data_46; - float d48 = src_data_37 + src_data_47; - - float d51 = src_data_50 + src_data_60; - float d52 = src_data_51 + src_data_61; - float d53 = src_data_52 + src_data_62; - float d54 = src_data_53 + src_data_63; - float d55 = src_data_54 + src_data_64; - float d56 = src_data_55 + src_data_65; - float d57 = src_data_56 + src_data_66; - float d58 = src_data_57 + src_data_67; - float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; @@ -1635,25 +1771,13 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float float s11 = t01 - t02; float s12 = t11 - t12; - float s21 = t03 - t04; float s22 = t13 - t14; - float s31 = t05 - t06; float s32 = t15 - t16; - float s41 = t01 + t02; - float s42 = t11 + t12; - - float s51 = t03 + t04; - float s52 = t13 + t14; - - float s61 = t05 + t06; - float s62 = t15 + t16; - float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; float m01 = 0.5f * s11 + s21 + 1.5f * s31 + t07; - float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; float m11 = 0.5f * s12 + s22 + 1.5f * s32 + t17; @@ -1668,70 +1792,296 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step) { #ifdef ENABLE_ARM -#else - for (int i = 0; i < C4NUM; i++) { - float src_data_00 = src_data[i]; - float src_data_01 = src_data[i + src_step]; - float src_data_02 = src_data[i + 2 * src_step]; - float src_data_03 = src_data[i + 3 * src_step]; - float src_data_04 = src_data[i + 4 * src_step]; - float src_data_05 = src_data[i + 5 * src_step]; - float src_data_06 = src_data[i + 6 * src_step]; - float src_data_07 = src_data[i + 7 * src_step]; - float src_data_10 = src_data[i + 8 * src_step]; - float src_data_11 = src_data[i + 9 * src_step]; - float src_data_12 = src_data[i + 10 * src_step]; - float src_data_13 = src_data[i + 11 * src_step]; - float src_data_14 = src_data[i + 12 * src_step]; - float src_data_15 = src_data[i + 13 * src_step]; - float src_data_16 = src_data[i + 14 * src_step]; - float src_data_17 = src_data[i + 15 * src_step]; - float src_data_20 = src_data[i + 16 * src_step]; - float src_data_21 = src_data[i + 17 * src_step]; - float src_data_22 = src_data[i + 18 * src_step]; - float src_data_23 = src_data[i + 19 * src_step]; - float src_data_24 = src_data[i + 20 * src_step]; - float src_data_25 = src_data[i + 21 * src_step]; - float src_data_26 = src_data[i + 22 * src_step]; - float src_data_27 = src_data[i + 23 * src_step]; - float src_data_30 = src_data[i + 24 * src_step]; - float src_data_31 = src_data[i + 25 * src_step]; - float src_data_32 = src_data[i + 26 * src_step]; - float src_data_33 = src_data[i + 27 * src_step]; - float src_data_34 = src_data[i + 28 * src_step]; - float src_data_35 = src_data[i + 29 * src_step]; - float src_data_36 = src_data[i + 30 * src_step]; - float src_data_37 = src_data[i + 31 * src_step]; - float src_data_40 = src_data[i + 32 * src_step]; - float src_data_41 = src_data[i + 33 * src_step]; - float src_data_42 = src_data[i + 34 * src_step]; - float src_data_43 = src_data[i + 35 * src_step]; - float src_data_44 = src_data[i + 36 * src_step]; - float src_data_45 = src_data[i + 37 * src_step]; - float src_data_46 = src_data[i + 38 * src_step]; - float src_data_47 = src_data[i + 39 * src_step]; - float src_data_50 = src_data[i + 40 * src_step]; - float src_data_51 = src_data[i + 41 * src_step]; - float src_data_52 = src_data[i + 42 * src_step]; - float src_data_53 = src_data[i + 43 * src_step]; - float src_data_54 = src_data[i + 44 * src_step]; - float src_data_55 = src_data[i + 45 * src_step]; - float src_data_56 = src_data[i + 46 * src_step]; - float src_data_57 = src_data[i + 47 * src_step]; - float src_data_60 = src_data[i + 48 * src_step]; - float src_data_61 = src_data[i + 49 * src_step]; - float src_data_62 = src_data[i + 50 * src_step]; - float src_data_63 = src_data[i + 51 * src_step]; - float src_data_64 = src_data[i + 52 * src_step]; - float src_data_65 = src_data[i + 53 * src_step]; - float src_data_66 = src_data[i + 54 * src_step]; - float src_data_67 = src_data[i + 55 * src_step]; - float src_data_70 = src_data[i + 56 * src_step]; - float src_data_71 = src_data[i + 57 * src_step]; - float src_data_72 = src_data[i + 58 * src_step]; - float src_data_73 = src_data[i + 59 * src_step]; - float src_data_74 = src_data[i + 60 * src_step]; - float src_data_75 = src_data[i + 61 * src_step]; + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)), src_data_70); + float32x4_t t21 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)), src_data_71); + float32x4_t t22 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)), src_data_72); + float32x4_t t23 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)), src_data_73); + float32x4_t t24 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)), src_data_74); + float32x4_t t25 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)), src_data_75); + float32x4_t t26 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)), src_data_76); + float32x4_t t27 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)), t27); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; float src_data_76 = src_data[i + 62 * src_step]; float src_data_77 = src_data[i + 63 * src_step]; @@ -1867,9 +2217,266 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float #endif } -void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { -#ifdef ENABLE_ARM +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)), src_data_70); + float32x4_t t31 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)), src_data_71); + float32x4_t t32 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)), src_data_72); + float32x4_t t33 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)), src_data_73); + float32x4_t t34 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)), src_data_74); + float32x4_t t35 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)), src_data_75); + float32x4_t t36 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)), src_data_76); + float32x4_t t37 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)), t37); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); #else for (int i = 0; i < C4NUM; i++) { float src_data_00 = src_data[i]; @@ -1973,136 +2580,436 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float float d37 = src_data_16 + src_data_26; float d38 = src_data_17 + src_data_27; - float d41 = src_data_30 + src_data_40; - float d42 = src_data_31 + src_data_41; - float d43 = src_data_32 + src_data_42; - float d44 = src_data_33 + src_data_43; - float d45 = src_data_34 + src_data_44; - float d46 = src_data_35 + src_data_45; - float d47 = src_data_36 + src_data_46; - float d48 = src_data_37 + src_data_47; + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70; + float t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71; + float t32 = 0.125f * d03 + d13 + 3.375f * d23 + src_data_72; + float t33 = 0.125f * d04 + d14 + 3.375f * d24 + src_data_73; + float t34 = 0.125f * d05 + d15 + 3.375f * d25 + src_data_74; + float t35 = 0.125f * d06 + d16 + 3.375f * d26 + src_data_75; + float t36 = 0.125f * d07 + d17 + 3.375f * d27 + src_data_76; + float t37 = 0.125f * d08 + d18 + 3.375f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34 + t37; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + } +#endif +} + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); - float d51 = src_data_50 + src_data_60; - float d52 = src_data_51 + src_data_61; - float d53 = src_data_52 + src_data_62; - float d54 = src_data_53 + src_data_63; - float d55 = src_data_54 + src_data_64; - float d56 = src_data_55 + src_data_65; - float d57 = src_data_56 + src_data_66; - float d58 = src_data_57 + src_data_67; + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); - float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; - float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; - float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; - float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; - float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; - float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; - float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; - float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); - float t10 = 0.5f * d01 + d11 + 1.5f * d21; - float t11 = 0.5f * d02 + d12 + 1.5f * d22; - float t12 = 0.5f * d03 + d13 + 1.5f * d23; - float t13 = 0.5f * d04 + d14 + 1.5f * d24; - float t14 = 0.5f * d05 + d15 + 1.5f * d25; - float t15 = 0.5f * d06 + d16 + 1.5f * d26; - float t16 = 0.5f * d07 + d17 + 1.5f * d27; - float t17 = 0.5f * d08 + d18 + 1.5f * d28; + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); - float t20 = 0.25f * d31 + d41 + 2.25f * d51; - float t21 = 0.25f * d32 + d42 + 2.25f * d52; - float t22 = 0.25f * d33 + d43 + 2.25f * d53; - float t23 = 0.25f * d34 + d44 + 2.25f * d54; - float t24 = 0.25f * d35 + d45 + 2.25f * d55; - float t25 = 0.25f * d36 + d46 + 2.25f * d56; - float t26 = 0.25f * d37 + d47 + 2.25f * d57; - float t27 = 0.25f * d38 + d48 + 2.25f * d58; + float32x4_t t40 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)), src_data_70); + float32x4_t t41 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)), src_data_71); + float32x4_t t42 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)), src_data_72); + float32x4_t t43 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)), src_data_73); + float32x4_t t44 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)), src_data_74); + float32x4_t t45 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)), src_data_75); + float32x4_t t46 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)), src_data_76); + float32x4_t t47 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)), src_data_77); - float t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70; - float t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71; - float t32 = 0.125f * d03 + d13 + 3.375f * d23 + src_data_72; - float t33 = 0.125f * d04 + d14 + 3.375f * d24 + src_data_73; - float t34 = 0.125f * d05 + d15 + 3.375f * d25 + src_data_74; - float t35 = 0.125f * d06 + d16 + 3.375f * d26 + src_data_75; - float t36 = 0.125f * d07 + d17 + 3.375f * d27 + src_data_76; - float t37 = 0.125f * d08 + d18 + 3.375f * d28 + src_data_77; + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); - float s11 = t01 - t02; - float s12 = t11 - t12; - float s13 = t21 - t22; - float s14 = t31 - t32; + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); - float s21 = t03 - t04; - float s22 = t13 - t14; - float s23 = t23 - t24; - float s24 = t33 - t34; + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); - float s31 = t05 - t06; - float s32 = t15 - t16; - float s33 = t25 - t26; - float s34 = t35 - t36; + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); - float s41 = t01 + t02; - float s42 = t11 + t12; - float s43 = t21 + t22; - float s44 = t31 + t32; + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); - float s51 = t03 + t04; - float s52 = t13 + t14; - float s53 = t23 + t24; - float s54 = t33 + t34; + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); - float s61 = t05 + t06; - float s62 = t15 + t16; - float s63 = t25 + t26; - float s64 = t35 + t36; + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)), t07); - float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; - float m01 = 0.5f * s11 + s21 + 1.5f * s31; - float m02 = 0.25f * s41 + s51 + 2.25f * s61; - float m03 = 0.125f * s11 + s21 + 3.375f * s31 + t07; + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)), t17); - float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; - float m11 = 0.5f * s12 + s22 + 1.5f * s32; - float m12 = 0.25f * s42 + s52 + 2.25f * s62; - float m13 = 0.125f * s12 + s22 + 3.375f * s32 + t17; + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)), t27); - float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; - float m21 = 0.5f * s13 + s23 + 1.5f * s33; - float m22 = 0.25f * s43 + s53 + 2.25f * s63; - float m23 = 0.125f * s13 + s23 + 3.375f * s33 + t27; + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)), t37); - float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; - float m31 = 0.5f * s14 + s24 + 1.5f * s34; - float m32 = 0.25f * s44 + s54 + 2.25f * s64; - float m33 = 0.125f * s14 + s24 + 3.375f * s34 + t37; + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)), t47); - (dst_data + i)[0] = m00 + bias_data[i]; - (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; - (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; - (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); - (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; - (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; - (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; - (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); - (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; - (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; - (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; - (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); - (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; - (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; - (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; - (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; - } -#endif -} + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); -void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, - int dst_step) { -#ifdef ENABLE_ARM + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); #else for (int i = 0; i < C4NUM; i++) { float src_data_00 = src_data[i]; @@ -3801,4 +4708,3 @@ OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit) { return nullptr; } } -