@@ -113,6 +113,9 @@ if (BUILD_DEVICE)
if (PLATFORM_ARM64)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
add_compile_definitions(ENABLE_ARM64)
if (ENABLE_FP16)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
endif ()
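# Note: armv8.2-a+dotprod+fp16 enables the ARMv8.2 half-precision arithmetic
# and signed dot-product extensions. Without +fp16 the compiler rejects
# float16_t arithmetic in the fp16 kernels; +dotprod is what the sdot-based
# optimized int8 assembly relies on at runtime.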
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
@@ -1,15 +1,16 @@
file(GLOB_RECURSE KERNEL_SRC
file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/int8/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/quantization/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
)
if (PLATFORM_ARM64)
# assembly
file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
@@ -17,13 +18,15 @@ endif()
if (PLATFORM_ARM32)
# assembly
file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s)
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s
${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.S
)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
if (ENABLE_FP16)
file(GLOB_RECURSE FP6_SRC
file(GLOB FP6_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc
)
@@ -101,7 +101,6 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
int ConvolutionBaseCPUKernel::Init() {
auto input = this->inputs_.front();
auto output = this->outputs_.front();
conv_param_->input_batch_ = input->Batch();
conv_param_->input_h_ = input->Height();
conv_param_->input_w_ = input->Width();
@@ -111,7 +110,6 @@ int ConvolutionBaseCPUKernel::Init() {
conv_param_->output_w_ = output->Width();
conv_param_->output_channel_ = output->Channel();
conv_param_->thread_num_ = ctx_->threadNum;
return RET_OK;
}
@@ -221,9 +219,24 @@ void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *con
}
}
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
bool CheckSupportFP16() {
bool support_fp16 = false;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
support_fp16 = true;
MS_LOG(INFO) << "Support FP16.";
} else {
support_fp16 = false;
| MS_LOG(INFO) << "Your machine doesn't support fp16, return back to float32 kernel."; | |||
}
#endif
return support_fp16;
}
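// CheckSupportFP16() infers fp16 capability indirectly: if the dynamically
// loaded optimized-op library (built with +fp16/+dotprod) could be opened,
// the CPU is assumed to support half-precision arithmetic. A sketch of the
// call-site pattern introduced by this patch:
//   bool support_fp16 = CheckSupportFP16();
//   // prefer an FP16 kernel only when both the runtime check passes and the
//   // binary was built with ENABLE_FP16.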
kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@@ -240,43 +253,34 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
bool support_fp16 = CheckSupportFP16();
if (kernel_h == 1 && kernel_w == 1) {
auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
if (support_fp16) {
#ifdef ENABLE_FP16
auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (use_winograd) {
auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit);
return kernel;
} else {
auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
if (support_fp16) {
#ifdef ENABLE_FP16
kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else {
auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
#endif
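// Kernel selection, as reworked here: 1x1 convolutions go to Convolution1x1,
// 3x3/stride-1/dilation-1 shapes go to the dedicated 3x3 kernel (FP16 variant
// when both ENABLE_FP16 and the runtime check allow it), Winograd-eligible
// shapes go to ConvolutionWinograd, and everything else falls back to the
// generic sliding-window ConvolutionCPUKernel.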
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
@@ -308,17 +312,10 @@ kernel::LiteKernel *CpuConvKernelCreator(const std::vector<lite::tensor::Tensor
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
break;
case kNumberTypeUInt8:
kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
#ifdef ENABLE_FP16
case kNumberTypeFloat16:
kernel = CpuConvFp16KernelCreator(inputs, outputs, opParameter, ctx);
break;
#endif
case kNumberTypeFloat32:
kernel = CpuConvFp32KernelCreator(inputs, outputs, opParameter, ctx);
kernel = CpuConvFloatKernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
@@ -385,8 +382,6 @@ kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector<lite::tensor::Tenso
case kNumberTypeInt8:
kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeUInt8:
break;
case kNumberTypeFloat32:
#ifdef ENABLE_FP16
kernel = CpuConvDwFp16KernelCreator(inputs, outputs, opParameter, ctx);
@@ -515,8 +510,6 @@ kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector<lite::tensor::Tenso
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
break;
case kNumberTypeUInt8:
kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
#ifdef ENABLE_FP16
@@ -26,10 +26,9 @@
#include <android/log.h>
#endif
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
using mindspore::lite::Context;
using mindspore::schema::PadMode;
@@ -40,7 +39,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
public:
ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
opParameter->thread_num_ = ctx->threadNum;
conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
}
@@ -63,6 +62,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
LayoutConvertor convert_func_;
};
void ComputeQuantOutRange(ConvParameter *conv_param);
bool CheckSupportFP16();
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_
@@ -49,6 +49,8 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
int Convolution3x3FP16CPUKernel::InitWeightBias() {
auto input_channel = conv_param_->input_channel_;
int output_channel = conv_param_->output_channel_;
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int iC4 = UP_DIV(input_channel, C4NUM);
int oC8 = UP_DIV(output_channel, C8NUM);
// init weight
@@ -60,8 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
}
memset(transformed_filter_addr_, 0, transformed_size);
float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = malloc(fp16_weight_size);
size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
@@ -74,16 +76,17 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
// init bias
size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
bias_data_ = reinterpret_cast<float16_t *>(malloc(new_bias_size));
bias_data_ = malloc(new_bias_size);
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, new_bias_size);
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (inputs_.size() == kInputSize2) {
auto ori_bias_addr = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
for (int i = 0; i < out_channel; ++i) {
bias_data_[i] = (float16_t)ori_bias_addr[i];
for (int i = 0; i < output_channel; ++i) {
fp16_bias_data[i] = (float16_t)ori_bias_addr[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
@@ -129,16 +132,15 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
}
memset(tmp_out_, 0, tmp_out_size);
size_t fp16_input_size =
in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = malloc(fp16_input_size);
size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
memset(fp16_input_, 0, fp16_input_size);
// init nhwc4 input
size_t nhwc4_input_size =
iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
@@ -249,4 +251,3 @@ int Convolution3x3FP16CPUKernel::Run() {
return RET_OK;
}
} // namespace mindspore::kernel
@@ -42,7 +42,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
// init weight
float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = malloc(fp16_weight_size);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
@@ -60,16 +60,17 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);
// init bias
bias_data_ = reinterpret_cast<float16_t *>(malloc(oc8 * C8NUM * sizeof(float16_t)));
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t));
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (inputs_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
for (int i = 0; i < out_channel; ++i) {
bias_data_[i] = (float16_t)ori_bias[i];
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
@@ -101,7 +102,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() {
size_t fp16_input_size =
in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = malloc(fp16_input_size);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
@@ -20,9 +20,7 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
namespace mindspore::kernel {
typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
@@ -33,7 +31,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ConvolutionFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
@@ -33,10 +33,17 @@ int ConvolutionCPUKernel::InitWeightBias() {
int kernel_w = conv_param_->kernel_w_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
int oc_block, oc_block_num;
#ifdef ENABLE_ARM32
oc_block = C4NUM;
oc_block_num = UP_DIV(out_channel, C4NUM);
#else
oc_block = C8NUM;
oc_block_num = UP_DIV(out_channel, C8NUM);
#endif
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
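// The output-channel block width appears sized to match the tile produced by
// the platform gemm: the arm64 IndirectGemmFp32_8x8 assembly writes 8 output
// channels per pass (C8NUM), while the arm32 path works on 4-channel blocks
// (C4NUM), so the packed weight size and the bias buffer are derived from
// oc_block / oc_block_num instead of a hard-coded oc8.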
// init weight
auto origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
@@ -49,12 +56,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
PackWeightFp32(origin_weight, conv_param_, packed_weight_);
// init bias
bias_data_ = reinterpret_cast<float *>(malloc(oc8 * C8NUM * sizeof(float)));
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc8 * C8NUM * sizeof(float));
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
if (inputs_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
@@ -198,4 +205,3 @@ int ConvolutionCPUKernel::Run() {
return RET_OK;
}
} // namespace mindspore::kernel
@@ -55,6 +55,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
}
#endif
conv_param_->tile_num_ = tile_num_;
}
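// tile_num_ is decided once here (24 when the sdot-optimized gemm is
// available, 4 for the plain int8 path) and recorded in conv_param_, so the
// im2col packing, the temp-buffer sizing and the gemm itself all read the
// same tile width instead of each hard-coding its own constant.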
int ConvolutionInt8CPUKernel::InitWeightBias() {
@@ -78,7 +79,7 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size);
int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = 0;
PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum);
@@ -105,15 +106,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
}
int ConvolutionInt8CPUKernel::InitTmpBuffer() {
int tile_n = 4;
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
int output_tile_count = UP_DIV(output_count, tile_num_);
int in_channel = conv_param_->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
int unit_size = plane_c4 * C4NUM * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int packed_input_size = output_tile_count * tile_num_ * unit_size;
packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
if (packed_input_ == nullptr) {
@@ -122,14 +122,14 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
}
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));
size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
if (tmp_dst_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
@@ -137,7 +137,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
}
memset(tmp_dst_, 0, tmp_dst_size);
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
return RET_ERROR;
@@ -173,7 +173,7 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
return RET_ERROR;
}
memset(packed_weight_, filter_zp, pack_weight_size);
int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane;
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
@@ -200,15 +200,13 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
}
int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
// todo
int tile_n = 24;
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, tile_num_);
int in_channel = conv_param_->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_num_ * unit_size;
packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
if (packed_input_ == nullptr) {
@@ -217,14 +215,14 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
}
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));
size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
if (tmp_dst_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
@@ -232,7 +230,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
}
memset(tmp_dst_, 0, tmp_dst_size);
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
return RET_ERROR;
@@ -5,19 +5,22 @@ set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
include_directories(OPTIMIZED_OP_DIR)
########################### optimized files ###########################
set(FP16_ASSEMBLY
# ${OPTIMIZED_OP_DIR}/assembly/arm64/IndirectGemmFp16_16x8.s
file(GLOB OPTIMIZED_ASSEMBLY
${OPTIMIZED_OP_DIR}/assembly/opt/*.s
${OPTIMIZED_OP_DIR}/assembly/opt/*.S
)
file(GLOB_RECURSE OPTIMIZED_INT8_ASSEMBLY
${OPTIMIZED_OP_DIR}/assembly/opt/*.S
file(GLOB FP16_SRC
# ${OPTIMIZED_OP_DIR}/fp16/*.cc
# ${OPTIMIZED_OP_DIR}/../fp16/*.cc
)
########################### share library build ########################
set(OPTIMIZED_OPS "opt_op_handler.c")
set_property(SOURCE ${OPTIMIZED_INT8_ASSEMBLY} PROPERTY LANGUAGE C)
list(APPEND OPTIMIZED_OPS ${OPTIMIZED_INT8_ASSEMBLY} ${FP16_ASSEMBLY})
set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C)
list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC})
if (PLATFORM_ARM64)
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/opclib/common_func.h"
#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h"
#ifndef ENABLE_ARM
#ifndef __aarch64__
void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4,
int output_channel, size_t offset, size_t relu, size_t relu6) {
for (int i = 0; i < TILE_NUM; i++) {
@@ -108,24 +108,49 @@ int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
void ReluFp32(float *data, int ele_num) {
for (int i = 0; i < ele_num; i++) {
if (data[i] < 0) {
data[i] = 0;
} else {
// do nothing
}
int four_block = UP_DIV(ele_num, C4NUM);
for (int i = 0; i < four_block - 1; i++) {
int index = i * C4NUM;
#ifdef ENABLE_NEON
float32x4_t relu_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0);
relu_data = vmaxq_f32(relu_data, zero_data);
vst1q_f32(data + index, relu_data);  // store the clamped lanes back
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
#endif
}
for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
}
}
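// Example of the block split: for ele_num = 10, four_block = UP_DIV(10, 4) = 3,
// so the vector loop covers blocks 0 and 1 (elements 0..7) and the scalar tail
// handles elements 8..9; full 4-lane loads never read past the end of data.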
void Relu6Fp32(float *data, int ele_num) {
for (int i = 0; i < ele_num; i++) {
if (data[i] < 0) {
data[i] = 0;
} else if (data[i] > 6) {
data[i] = 6;
} else {
// do nothing
}
int four_block = UP_DIV(ele_num, C4NUM);
for (int i = 0; i < four_block - 1; i++) {
int index = i * C4NUM;
#ifdef ENABLE_NEON
float32x4_t relu6_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0);
float32x4_t six_data = vdupq_n_f32(6);
relu6_data = vmaxq_f32(relu6_data, zero_data);
relu6_data = vminq_f32(relu6_data, six_data);
vst1q_f32(data + index, relu6_data);  // store the clamped lanes back
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index] = data[index] > 6 ? 6 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3];
#endif
}
for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
data[j] = data[j] > 6 ? 6 : data[j];
}
}
@@ -39,7 +39,7 @@ struct ConvParameter {
int pad_l_;
int pad_r_;
int group_;
int n_dim_;
int tile_num_;
int input_batch_;
int input_h_;
int input_w_;
@@ -141,7 +141,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
float16_t *gemm_input =
(float *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
(float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
@@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/opclib/fp32/common_func.h"
#ifndef ENABLE_ARM
#ifndef __aarch64__
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
size_t row, size_t col) {
for (int r = 0; r < row; r++) {
@@ -31,13 +31,12 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int thread_count = conv_param->thread_num_;
int tile_n = 8;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
// we accumulate 4 channels per time for input blocks
int conv_depth = kernel_h * kernel_w;
@@ -50,13 +49,13 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int out_batch_offset = b * out_channel * out_h * out_w;
int gemm_in_batch_offset = b * packed_input_size;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
float *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset;
Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
if (real_cal_num == tile_n) {
int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset;
if (real_cal_num == TILE_NUM) {
float *gemm_output = output_data + out_offset;
IndirectGemmFp32_8x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_);
@@ -121,22 +120,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
}
// get real output
for (int batch = 0; batch < out_batch; batch++) {
int batch_size = batch * out_channel * conv_param->output_h_ * conv_param->output_w_;
for (int h = 0; h < conv_param->output_h_; h++) {
for (int w = 0; w < conv_param->output_w_; w++) {
for (int c = 0; c < out_channel; c++) {
int oc4_block = c / C4NUM;
int oc4_res = c % C4NUM;
int src_offset = oc4_block * C4NUM * out_w_block * out_h_block * out_unit * out_unit +
C4NUM * (h * out_w_block * out_unit + w) + oc4_res;
int dst_offset = (h * conv_param->output_w_ + w) * out_channel + c;
(output_data + dst_offset)[0] = (tmp_out_data + src_offset)[0];
}
}
}
}
UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel,
out_unit);
int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_num);
@@ -147,6 +132,45 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
}
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit) {
int out_h_block_num = UP_DIV(height, output_unit);
int out_w_block_num = UP_DIV(width, output_unit);
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_batch_offset = b * c4 * C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
int dst_batch_offset = b * height * width * channel;
for (int h = 0; h < height; h++) {
int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
int dst_h_offset = dst_batch_offset + h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C4NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c4_offset = src_w_offset + c * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c4_offset = dst_w_offset + c * C4NUM;
#ifdef ENABLE_NEON
vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset));
#else
dst[dst_c4_offset] = src[src_c4_offset];
dst[dst_c4_offset + 1] = src[src_c4_offset + 1];
dst[dst_c4_offset + 2] = src[src_c4_offset + 2];
dst[dst_c4_offset + 3] = src[src_c4_offset + 3];
#endif
}
int c_res = channel - (c4 - 1) * C4NUM;
int src_c_res_offset = (c4 - 1) * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c_res_offset = (c4 - 1) * C4NUM;
for (int c = 0; c < c_res; c++) {
int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
dst[dst_c4_res_offset] = src[src_c4_res_offset];
}
}
}
}
}
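// UnPackWinogradOutput converts the NC4HW4-tiled winograd result (channel
// groups of 4, each group laid out plane-major over the padded
// out_block * output_unit grid) back to NHWC, copying whole 4-channel groups
// (one vst1q_f32 under NEON) and finishing the channel remainder scalar-wise.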
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param) {
@@ -182,7 +206,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
}
PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
}
int output_num = oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_num);
} else if (is_relu6) {
@@ -191,4 +215,3 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
// do nothing
}
}
@@ -42,10 +42,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id,
ConvParameter *conv_param);
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_
@@ -81,8 +81,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
float32x4_t dup_count = vdupq_n_f32(real_count);
vst1q_f32(output_ptr + out_channel_offset, vdivq_f32(tmp_avg, dup_count));
float reverse_count = 1.0f / real_count;
float32x4_t dup_count = vdupq_n_f32(reverse_count);
vst1q_f32(output_ptr + out_channel_offset, vmulq_f32(tmp_avg, dup_count));
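// Multiplying by the reciprocal replaces vdivq_f32, which is only available
// on AArch64; vmulq_f32 keeps this path buildable on 32-bit NEON targets and
// is cheaper than a vector divide.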
#else
*(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count;
*(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count;
@@ -24,11 +24,6 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
size_t oc4, size_t offset);
#ifdef ENABLE_ARM64
// void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t
// ksize,
// size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum,
// size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, size_t
// shift_before, size_t shift_after);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
@@ -54,8 +49,9 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
#ifdef __aarch64__
IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t),
input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
// todo arm32
#else
int tile_num = 4;
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
for (int oc = 0; oc < output_channel; oc++) {
int oc4_block = oc / C4NUM;
@@ -109,7 +105,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
#endif
} else {
int tile_num = 24;
int tile_num = conv_param->tile_num_;
for (int oc = 0; oc < output_channel; oc++) {
int oc4_block = oc / C4NUM;
int oc4_res = oc % C4NUM;
@@ -202,7 +198,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int tile_n = 4;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
@@ -255,9 +251,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
// todo
int tile_n = 24;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
@@ -302,7 +296,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param) {
// todo
int thread_count = conv_param->thread_num_;
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int output_batch = conv_param->output_batch_;
@@ -331,8 +324,5 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
}
// get real output
for (int batch = 0; batch < output_batch; batch++) {
// int batch_size = batch * output_channel * output_h * output_w;
C4UnpackToHwcInt8(tmp_out, output_data, output_channel, output_h, output_w);
}
PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel);
}
@@ -16,7 +16,6 @@
#include <stdlib.h>
// todo
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
@@ -345,6 +345,7 @@ void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, i
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) {
// original weight format : ohwi
// todo pack weight for arm32 platform
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
@@ -352,7 +353,7 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed
int oc8 = UP_DIV(out_channel, C8NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
int pack_weight_size = oc8 * C8NUM * ic4 * C4NUM * kernel_plane;
int unit_size = C8NUM * C4NUM;
int block_size = pack_weight_size / oc8;
@@ -565,7 +566,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = 4;
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@@ -624,7 +625,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = 24;
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@@ -980,15 +981,23 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * channel;
((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_offset + k * C4NUM;
int dst_kernel_offset = dst_offset + k * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c_offset = src_kernel_offset + c * plane * C4NUM;
int dst_c_offset = dst_kernel_offset + c * C4NUM;
((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0];
((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1];
((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2];
((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3];
}
// res part
int res_c = channel - (c4 - 1) * C4NUM;
for (int i = 0; i < res_c; i++) {
int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0];
}
}
}
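// The rewritten PackNC4HW4ToNHWCInt8 swaps the loop order: iterating the
// spatial plane in the outer loop and copying each 4-channel group as a
// contiguous 4-byte run (plus a scalar remainder) gives sequential writes on
// the NHWC side instead of a strided byte-at-a-time walk per channel.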
@@ -122,52 +122,4 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
inline void UnpackHwcToChwFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int i = 0; i < channel; i++) {
auto plane = i / BLOCK;
auto offset = i % BLOCK;
auto src_plane = plane * h * w * BLOCK + src_ptr;
for (int j = 0; j < h * w; j++) {
dst_ptr[cur++] = src_plane[j * BLOCK + offset];
}
}
}
inline void C8UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 8;
auto offset = i % 8;
auto src_plane = plane * h * w * 8 + src_ptr;
dst_ptr[cur++] = src_plane[j * 8 + offset];
}
}
}
inline void C4UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 4;
auto offset = i % 4;
auto src_plane = plane * h * w * 4 + src_ptr;
dst_ptr[cur++] = src_plane[j * 4 + offset];
}
}
}
inline void C4UnpackToHwcInt8(int8_t *src_ptr, int8_t *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 4;
auto offset = i % 4;
auto src_plane = plane * h * w * 4 + src_ptr;
dst_ptr[cur++] = src_plane[j * 4 + offset];
}
}
}
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_
@@ -17,659 +17,43 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <limits.h>
#include "include/infer_log.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
// Part 1: Low-level integer-arithmetic primitives.
// The implementations here are generic implementations valid for
// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
// (e.g. NEON int32x4_t) may be supported by providing
// specializations for them in separate files.
//
// The purpose of these primitives is two-fold:
// - They will be used to implement higher-level fixed-point
// abstractions, namely the FixedPoint class and its arithmetic
// operators.
// - They will be directly used to implement some more involved
// fixed-point computations, e.g. the fixed-point implementation
// of math functions such as tanh.
// Some compile-time traits around raw types to handle SIMD aspects:
// number of lanes, underlying scalar type.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
typedef std::int32_t ScalarRawType;
static constexpr int kLanes = 1;
};
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
typedef std::int16_t ScalarRawType;
static constexpr int kLanes = 1;
};
// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
return x;
}
// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
return a & b;
}
// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
return a | b;
}
// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
return a ^ b;
}
// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
return ~a;
}
// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
return a + b;
}
// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
return a * b;
}
// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
return a - b;
}
// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
return -a;
}
// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
const std::int64_t wide_a = (std::int64_t)(a);
const std::int64_t wide_shifted = wide_a * (1 << offset);
const auto min = std::numeric_limits<tIntegerType>::min();
const auto max = std::numeric_limits<tIntegerType>::max();
return wide_shifted < min ? min : wide_shifted > max ? max : (tIntegerType)(wide_shifted);
}
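// Example: ShiftLeft<std::int16_t>(20000, 2) computes 80000 in 64-bit, which
// exceeds INT16_MAX, so the result saturates to 32767 instead of wrapping.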
// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
return a >> offset;
}
// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) {
return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
}
// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
static constexpr tIntegerType zero = 0;
return a ? BitNot(zero) : zero;
}
// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
return MaskIfNonZero<tIntegerType>(!a);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a == b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a != b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a > b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a >= b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a < b.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a < b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a <= b.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a <= b);
}
// Returns true if all of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
return a;
}
// Returns true if any of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
return a;
}
// Returns (a+b)/2, rounded to the nearest integer.
// Equivalent to VRHADD in the ARM NEON instruction set.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
(void)b;
return a;
}
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t sum = a64 + b64;
std::int64_t sign = sum >= 0 ? 1 : -1;
return (std::int32_t)((sum + sign) / 2);
}
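// Rounding check: a = 3, b = 4 gives sum = 7, sign = +1, (7 + 1) / 2 = 4,
// i.e. 3.5 rounds away from zero to 4; a = -3, b = -4 rounds to -4 likewise.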
| template <> | |||
| inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { | |||
| std::int32_t a32 = a; | |||
| std::int32_t b32 = b; | |||
| std::int32_t sum = a32 + b32; | |||
| std::int32_t sign = sum >= 0 ? 1 : -1; | |||
| return (std::int16_t)((sum + sign) / 2); | |||
| } | |||
| template <typename IntegerType> | |||
| IntegerType SaturatingAdd(IntegerType a, IntegerType b) { | |||
| static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); | |||
| (void)b; | |||
| return a; | |||
| } | |||
| // So far this is only needed for int16. | |||
| template <> | |||
| inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { | |||
| std::int32_t a32 = a; | |||
| std::int32_t b32 = b; | |||
| std::int32_t sum = a32 + b32; | |||
| return (std::int16_t)(std::min((std::int32_t)(std::numeric_limits<std::int16_t>::max()), | |||
| std::max((std::int32_t)(std::numeric_limits<std::int16_t>::min()), sum))); | |||
| } | |||
| template <> | |||
| inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { | |||
| std::int16_t a16 = a; | |||
| std::int16_t b16 = b; | |||
| std::int16_t sum = a16 + b16; | |||
| return (std::int8_t)(std::min((int16_t)(std::numeric_limits<int8_t>::max()), | |||
| std::max((int16_t)(std::numeric_limits<int8_t>::min()), sum))); | |||
| } | |||
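| // Worked example (illustrative): SaturatingAdd<std::int8_t>(100, 100) computes | |||
| // 200 in 16-bit, then clamps to the int8 maximum 127 instead of wrapping to -56. | |||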
| // Returns a+b, saturating if the scalar integer type is 16-bit, | |||
| // otherwise just a plain addition. | |||
| template <typename IntegerType, bool Is16Bit> | |||
| struct AddSaturatingIf16BitImpl { | |||
| static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } | |||
| }; | |||
| template <typename IntegerType> | |||
| struct AddSaturatingIf16BitImpl<IntegerType, true> { | |||
| static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); } | |||
| }; | |||
| template <typename IntegerType> | |||
| IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { | |||
| using ScalarType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; | |||
| return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, b); | |||
| } | |||
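| // Illustrative note: the dispatch is purely compile-time, e.g. | |||
| //   AddSaturatingIf16Bit<std::int16_t>(30000, 10000) == 32767 (saturated), | |||
| // while for std::int32_t it lowers to a plain Add. | |||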
| // Returns the integer that represents the product of two fixed-point | |||
| // numbers, interpreting all integers as fixed-point values in the | |||
| // interval [-1, 1), rounding to the nearest value, and saturating | |||
| // -1 * -1 to the maximum value (since 1 is not in the half-open | |||
| // interval [-1, 1)). | |||
| // | |||
| // [The explanation below specializes to std::int32_t for example purposes.] | |||
| // | |||
| // The mapping between IntegerType and the interval [-1, 1) is unique and | |||
| // implied by IntegerType, which is assumed to be signed. For example, | |||
| // for IntegerType==std::int32_t, the mapping is | |||
| // real_value = integer_value / 2^31. | |||
| // So in this case, and leaving aside rounding and saturating, this | |||
| // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to | |||
| // (a * b) / 2^31. | |||
| // | |||
| // The 'doubling' part in the name of this function comes from the fact that | |||
| // this operation is very close to a "multiply-high" operation, keeping only | |||
| // the top half bits, except that that would be effectively computing | |||
| // (a * b) / 2^32, | |||
| // so here we are computing 2x that, since | |||
| // 1/2^31 = 2 * 1/2^32. | |||
| // The idea is to use all of the available 32 bits in the destination int32 | |||
| // value. | |||
| // | |||
| // [End of the explanation specializing to int32.] | |||
| // | |||
| // This is equivalent to the VQRDMULH instruction in ARM NEON. | |||
| template <typename IntegerType> | |||
| IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { | |||
| static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); | |||
| (void)b; | |||
| return a; | |||
| } | |||
| // This function implements the same computation as the ARMv7 NEON VQRDMULH | |||
| // instruction. | |||
| template <> | |||
| inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) { | |||
| bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); | |||
| std::int64_t a_64(a); | |||
| std::int64_t b_64(b); | |||
| std::int64_t ab_64 = a_64 * b_64; | |||
| std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); | |||
| std::int32_t ab_x2_high32 = (std::int32_t)((ab_64 + nudge) / (1ll << 31)); | |||
| return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; | |||
| } | |||
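| // Worked example (illustrative): with a = b = 1 << 30 (i.e. 0.5 in Q0.31), | |||
| //   ab_64 = 2^60, nudge = 2^30, (2^60 + 2^30) / 2^31 = 2^29, | |||
| // and 2^29 is 0.25 in Q0.31, as expected for 0.5 * 0.5. | |||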
| template <> | |||
| inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) { | |||
| bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min(); | |||
| std::int32_t a_32(a); | |||
| std::int32_t b_32(b); | |||
| std::int32_t ab_32 = a_32 * b_32; | |||
| std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); | |||
| std::int16_t ab_x2_high16 = (std::int16_t)((ab_32 + nudge) / (1 << 15)); | |||
| return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16; | |||
| } | |||
| // Returns the high 32 bits of a * b, with rounding. | |||
| // a and b are interpreted as values divided by 2^31, falling into [-1, 1), | |||
| // so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31; | |||
| // we actually compute 2 * a * b / 2^32 and keep 32 bits of mantissa, rounded. | |||
| inline int SaturatingRoundingDoublingHighMul(int a, int b) { | |||
| if (a == INT_MIN && b == INT_MIN) { | |||
| return INT_MAX; | |||
| } | |||
| int64_t ab = ((int64_t)a) * ((int64_t)b); | |||
| int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30)); | |||
| // do not apply a right shift to potentially negative values | |||
| int ab_mantissa = (int) ((ab + rounding) / (1ll << 31)); | |||
| return ab_mantissa; | |||
| } | |||
| // Correctly-rounded-to-nearest division by a power-of-two. | |||
| // Also known as a rounding arithmetic right shift. | |||
| template <typename IntegerType, typename ExponentType> | |||
| inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { | |||
| assert(exponent >= 0); | |||
| assert(exponent <= 31); | |||
| const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); | |||
| const IntegerType zero = Dup<IntegerType>(0); | |||
| const IntegerType one = Dup<IntegerType>(1); | |||
| const IntegerType remainder = BitAnd(x, mask); | |||
| const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); | |||
| return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one)); | |||
| } | |||
| // Division by 2^exponent with rounding, | |||
| // i.e. an arithmetic right shift with rounding. | |||
| inline int RoundingDivideByPOT(int x, int exponent) { | |||
| MS_ASSERT(exponent >= 0); | |||
| MS_ASSERT(exponent <= 31); | |||
| const int mask = (1ll << exponent) - 1; | |||
| const int remainder = x & mask; | |||
| const int threshold = (mask >> 1) + (x < 0 ? 1 : 0); | |||
| return (x >> exponent) + (remainder > threshold ? 1 : 0); | |||
| } | |||
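| // Worked example (illustrative): RoundingDivideByPOT(5, 1): | |||
| //   mask = 1, remainder = 1, threshold = 0, so (5 >> 1) + 1 = 3  (5 / 2 = 2.5 -> 3). | |||
| // For x = -5: remainder = 1, threshold = 1, 1 > 1 is false, so the result is -5 >> 1 = -3. | |||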
| inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { | |||
| return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); | |||
| } | |||
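| // Note: since RoundingDivideByPOT requires a non-negative exponent, right_shift | |||
| // is non-positive here, and the helper computes approximately | |||
| //   value * (multiplier / 2^31) * 2^(left_shift + right_shift). | |||
| // A minimal sketch (an assumption, not part of this header) of how such a | |||
| // (multiplier, shift) pair is typically derived from a real-valued rescale | |||
| // factor; the name QuantizeMultiplierSketch is hypothetical: | |||
| inline void QuantizeMultiplierSketch(double real_multiplier, int32_t *multiplier, int *shift) { | |||
| if (real_multiplier == 0.0) { | |||
| *multiplier = 0; | |||
| *shift = 0; | |||
| return; | |||
| } | |||
| // std::frexp: real_multiplier == q * 2^(*shift), with |q| in [0.5, 1). | |||
| const double q = std::frexp(real_multiplier, shift); | |||
| int64_t q_fixed = (int64_t)(round(q * (double)(1ll << 31))); | |||
| if (q_fixed == (1ll << 31)) {  // q rounded up to exactly 1.0 | |||
| q_fixed /= 2; | |||
| ++(*shift); | |||
| } | |||
| *multiplier = (int32_t)(q_fixed); | |||
| } | |||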
| // Returns the product of a run-time integer value by a compile-time power | |||
| // of two, with either a positive exponent (equivalent to an arithmetic | |||
| // left shift, saturating) or a negative exponent (equivalent to an arithmetic | |||
| // right shift, rounding to nearest). | |||
| template <int Exponent, typename IntegerType, int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> | |||
| struct ImplSaturatingRoundingMultiplyByPOT {}; | |||
| template <int Exponent, typename IntegerType> | |||
| struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { | |||
| static IntegerType eval(IntegerType x) { return x; } | |||
| }; | |||
| template <int Exponent, typename IntegerType> | |||
| struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { | |||
| static IntegerType eval(IntegerType x) { | |||
| using ScalarIntegerType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; | |||
| const IntegerType min = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); | |||
| const IntegerType max = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); | |||
| const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); | |||
| const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); | |||
| const IntegerType positive_mask = MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); | |||
| const IntegerType negative_mask = MaskIfLessThan(x, Dup<IntegerType>(-threshold)); | |||
| IntegerType result = ShiftLeft(x, Exponent); | |||
| result = SelectUsingMask(positive_mask, max, result); | |||
| result = SelectUsingMask(negative_mask, min, result); | |||
| return result; | |||
| } | |||
| }; | |||
| template <int Exponent, typename IntegerType> | |||
| struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { | |||
| static IntegerType eval(IntegerType x) { return RoundingDivideByPOT<IntegerType>(x, -Exponent); } | |||
| }; | |||
| template <int Exponent, typename IntegerType> | |||
| IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { | |||
| return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); | |||
| } | |||
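| // Illustrative use: SaturatingRoundingMultiplyByPOT<3>(x) is a saturating x << 3, | |||
| // SaturatingRoundingMultiplyByPOT<-3>(x) is RoundingDivideByPOT(x, 3), | |||
| // and the Exponent == 0 case is the identity. | |||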
| // Part 2: the FixedPoint class. | |||
| // A FixedPoint object represents a fixed-point value stored in the underlying | |||
| // integer type tRawType, if tRawType is a plain scalar integer type. | |||
| // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which | |||
| // case a FixedPoint object represents a corresponding SIMD vector of fixed | |||
| // point values. | |||
| // | |||
| // tIntegerBits describes the range of the fixed-point format: if | |||
| // tIntegerBits == m then the range of representable values is the half-open | |||
| // interval [-2^m; 2^m) where the open boundary on the right side means that | |||
| // 2^m is not representable (how close the maximum representable value is to | |||
| // it, depends on bit-depth of tRawType). | |||
| // | |||
| // In "Q format notation", | |||
| // https://en.wikipedia.org/wiki/Q_(number_format) | |||
| // we are describing the format | |||
| // Qm.n | |||
| // where | |||
| // m = tIntegerBits | |||
| // and | |||
| // n = NumberOfBits(tRawType) - (m + 1) | |||
| // Note that the (m + 1) in the above line is because we adopt the convention | |||
| // that we count the integer bits exclusively of the sign bit; so (m + 1) is | |||
| // the total number of integer bits inclusive of the sign bit. | |||
| // | |||
| // Accordingly, the number of integral representable values in our range | |||
| // [-2^m ; 2^m) | |||
| // is equal to 2^(m+1). | |||
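| // Worked example (illustrative): for tRawType = std::int32_t and tIntegerBits = 2, | |||
| // the format is Q2.29, real_value = raw / 2^29, the representable range is [-4, 4), | |||
| // and e.g. 1.25 is stored as raw = 1.25 * 2^29 = 0x28000000. | |||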
| template <typename tRawType, int tIntegerBits> | |||
| class FixedPoint { | |||
| public: | |||
| typedef tRawType RawType; | |||
| typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; | |||
| typedef typename RawTypeTraits::ScalarRawType ScalarRawType; | |||
| static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); | |||
| static constexpr int kIntegerBits = tIntegerBits; | |||
| static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; | |||
| static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); | |||
| typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; | |||
| static const ScalarRawType ScalarRawMin() { return std::numeric_limits<ScalarRawType>::min(); } | |||
| static const ScalarRawType ScalarRawMax() { return std::numeric_limits<ScalarRawType>::max(); } | |||
| static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); } | |||
| static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); } | |||
| static FixedPoint FromRaw(RawType x) { | |||
| FixedPoint retval; | |||
| retval.raw() = x; | |||
| return retval; | |||
| } | |||
| static FixedPoint FromScalarRaw(ScalarRawType x) { | |||
| FixedPoint retval; | |||
| retval.raw() = Dup<RawType>(x); | |||
| return retval; | |||
| } | |||
| static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); } | |||
| template <int Exponent> | |||
| static FixedPoint ConstantPOT() { | |||
| static constexpr int kOffset = kFractionalBits + Exponent; | |||
| static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format"); | |||
| return FromScalarRaw(ScalarRawType(1) << kOffset); | |||
| } | |||
| static FixedPoint Zero() { return FromScalarRaw(0); } | |||
| static FixedPoint One() { | |||
| return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() | |||
| : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); | |||
| } | |||
| static FixedPoint FromDouble(double x) { | |||
| const double min_bound = (double)(ScalarRawMin()); | |||
| const double max_bound = (double)(ScalarRawMax()); | |||
| return FromScalarRaw( | |||
| (ScalarRawType)(std::min(std::max(round(x * (double)(1ll << kFractionalBits)), min_bound), max_bound))); | |||
| } | |||
| RawType raw() const { return i_; } | |||
| RawType &raw() { return i_; } | |||
| private: | |||
| RawType i_; | |||
| }; | |||
| // Part 3: implementation of arithmetic operators for the | |||
| // FixedPoint class, and a few related functions. | |||
| // A FixedPoint multiplication is just a | |||
| // SaturatingRoundingDoublingHighMul operation on the underlying | |||
| // raw integer values. The IntegerBits simply add up, as is obvious | |||
| // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). | |||
| template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> | |||
| FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(FixedPoint<tRawType, tIntegerBits_a> a, | |||
| FixedPoint<tRawType, tIntegerBits_b> b) { | |||
| FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; | |||
| c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); | |||
| return c; | |||
| } | |||
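| // Worked example (illustrative): multiplying a Q1.30 value by a Q2.29 value | |||
| // yields a Q3.28 value: the integer bits add (1 + 2 = 3), and the raw product | |||
| // comes from SaturatingRoundingDoublingHighMul on the two raw int32 values. | |||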
| // Tweaking IntegerBits gives exact multiplication by a power of two. | |||
| template <int tExponent, typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(FixedPoint<tRawType, tIntegerBits> a) { | |||
| FixedPoint<tRawType, tExponent + tIntegerBits> c; | |||
| c.raw() = a.raw(); | |||
| return c; | |||
| } | |||
| // If we want to leave IntegerBits fixed, then multiplication | |||
| // by a power of two has to be saturating/rounding, not exact anymore. | |||
| template <int tExponent, typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(FixedPoint<tRawType, tIntegerBits> a) { | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); | |||
| } | |||
| // Generic arithmetic operators. | |||
| #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ | |||
| template <typename tRawType, int tIntegerBits> \ | |||
| FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ | |||
| } | |||
| #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ | |||
| template <typename tRawType, int tIntegerBits> \ | |||
| FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a, \ | |||
| FixedPoint<tRawType, tIntegerBits> b) { \ | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw(), b.raw())); \ | |||
| } | |||
| MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) | |||
| MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) | |||
| #undef MAKE_FIXEDPOINT_UNARY_FUNC | |||
| #undef MAKE_FIXEDPOINT_BINARY_FUNC | |||
| #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ | |||
| template <typename tRawType, int tIntegerBits> \ | |||
| tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ | |||
| return FuncName(a.raw()); \ | |||
| } | |||
| #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ | |||
| template <typename tRawType, int tIntegerBits> \ | |||
| tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { \ | |||
| return FuncName(a.raw(), b.raw()); \ | |||
| } | |||
| MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) | |||
| MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) | |||
| MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) | |||
| #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW | |||
| #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, tIntegerBits> SelectUsingMask(tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, | |||
| FixedPoint<tRawType, tIntegerBits> else_val) { | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); | |||
| } | |||
| template <typename tRawType, int tIntegerBits> | |||
| bool operator==(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { | |||
| return All(MaskIfEqual(a.raw(), b.raw())); | |||
| } | |||
| template <typename tRawType, int tIntegerBits> | |||
| bool operator!=(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { | |||
| return !(a == b); | |||
| } | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, tIntegerBits> SaturatingAdd(FixedPoint<tRawType, tIntegerBits> a, | |||
| FixedPoint<tRawType, tIntegerBits> b) { | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingAdd(a.raw(), b.raw())); | |||
| } | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(FixedPoint<tRawType, tIntegerBits> a, | |||
| FixedPoint<tRawType, tIntegerBits> b) { | |||
| return FixedPoint<tRawType, tIntegerBits>::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw())); | |||
| } | |||
| // Conversion to floating-point. | |||
| template <typename tRawType, int tIntegerBits> | |||
| double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { | |||
| static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, "not applicable to SIMD types"); | |||
| typedef FixedPoint<tRawType, tIntegerBits> F; | |||
| return x.raw() / (double)(1ll << F::kFractionalBits); | |||
| } | |||
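| // Illustrative: for FixedPoint<std::int32_t, 0> (Q0.31), raw = 1 << 30 | |||
| // converts to (1 << 30) / 2^31 = 0.5. | |||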
| // Rescale changes the number of IntegerBits and updates the underlying | |||
| // raw integer value accordingly. | |||
| template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> | |||
| FixedPoint<tRawType, tIntegerBitsDst> Rescale(FixedPoint<tRawType, tIntegerBitsSrc> x) { | |||
| static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; | |||
| FixedPoint<tRawType, tIntegerBitsDst> result; | |||
| result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); | |||
| return result; | |||
| } | |||
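| // Worked example (illustrative): rescaling a Q2.29 value to Q0.31 uses | |||
| // kExponent = 2 - 0 = 2, i.e. a saturating left shift of the raw value by 2, | |||
| // since the same real number needs 4x the raw magnitude at 31 fractional bits. | |||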
| // CheckedFixedPointConstant allows specifying fixed-point constants | |||
| // initialized as real numbers, in a way that does not compile floating-point | |||
| // arithmetic in production code, yet still checks agreement with the | |||
| // floating-point expressions when asserts are enabled. | |||
| // | |||
| // The raw integer value provided is always an int32, encoding a 32-bit | |||
| // fixed-point value, regardless of the actual Scalar type. This allows | |||
| // writing generic code that applies just as well to the 32-bit and 16-bit | |||
| // cases. In the 16-bit case, the raw integer value is internally | |||
| // rounding-shifted by 16 bits to the right. | |||
| template <typename FixedPointType> | |||
| inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(std::int32_t int32_value) { | |||
| typedef typename FixedPointType::ScalarRawType ScalarRawType; | |||
| static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); | |||
| return (ScalarRawType)(RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); | |||
| } | |||
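| // Illustrative: for a 16-bit FixedPointType, ScalarTypeBits = 16, so the | |||
| // 32-bit constant is rounding-shifted right by 32 - 16 = 16 bits; for a | |||
| // 32-bit type the shift is 0 and the constant is used as-is. | |||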
| // Implementation of tanh. | |||
| // Returns -tanh(x) for x < 0. | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, 0> neg_tanh_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) { | |||
| return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a))); | |||
| } | |||
| // Returns tanh(x) for any x. | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { | |||
| typedef FixedPoint<tRawType, tIntegerBits> InputF; | |||
| typedef FixedPoint<tRawType, 0> ResultF; | |||
| tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); | |||
| tRawType mask_if_zero = MaskIfZero(a); | |||
| InputF n = SelectUsingMask(mask_if_negative, a, -a); | |||
| ResultF t = neg_tanh_on_negative_values(n); | |||
| return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t)); | |||
| } | |||
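| // Design note (illustrative): tanh is odd, so the argument is reduced to | |||
| // n = -|a| <= 0; t = neg_tanh_on_negative_values(n) = tanh(|a|), and the sign | |||
| // mask restores tanh(a) = -t for negative a, with the exact-zero case forced to 0. | |||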
| // Implementation of logistic function. | |||
| // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. | |||
| template <typename tRawType, int tIntegerBits> | |||
| FixedPoint<tRawType, 0> logistic_on_positive_values(FixedPoint<tRawType, tIntegerBits> a) { | |||
| return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); | |||
| } | |||
| #ifdef ENABLE_NEON | |||
| inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { | |||
| const int32x4_t shift_vec = vdupq_n_s32(-exponent); | |||