From: @ling_qiao_min Reviewed-by: @zhang_xue_tong Signed-off-by:tags/v1.2.0-rc1
| @@ -15,7 +15,7 @@ file(GLOB KERNEL_SRC | |||
| ${NNACL_DIR}/*.c | |||
| ${NNACL_DIR}/fp32/*.c | |||
| ${NNACL_DIR}/int8/*.c | |||
| ${NNACL_DIR}/quantization/*.c | |||
| ${NNACL_DIR}/base/*.c | |||
| ) | |||
| if (SUPPORT_TRAIN) | |||
| @@ -42,12 +42,4 @@ typedef struct ArithmeticParameter { | |||
| int multiples1_[10]; | |||
| } ArithmeticParameter; | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void CalcMultiplesAndStrides(ArithmeticParameter *param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_ARTITHMETIC_H_ | |||
| @@ -19,7 +19,7 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| // For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. | |||
| typedef struct ArithmeticSelfParameter { | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/arithmetic.h" | |||
| #include "nnacl/base/arithmetic_base.h" | |||
| void CalcMultiplesAndStrides(ArithmeticParameter *param) { | |||
| NNACL_ASSERT(param->in_shape0_[i] != 0); | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ | |||
| #define MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ | |||
| #include "nnacl/arithmetic.h" | |||
| #include "nnacl/nnacl_utils.h" | |||
| #include "nnacl/nnacl_common.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void CalcMultiplesAndStrides(ArithmeticParameter *param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/base/conv1x1_base.h" | |||
| void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { | |||
| /* support nhwc */ | |||
| char *src = (char *)src_ptr; | |||
| char *dst = (char *)dst_ptr; | |||
| for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { | |||
| int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; | |||
| if (src_h < 0 || src_h >= conv_param->input_h_) { | |||
| continue; | |||
| } | |||
| const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; | |||
| char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; | |||
| for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { | |||
| int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; | |||
| if (src_w < 0 || src_w >= conv_param->input_w_) { | |||
| continue; | |||
| } | |||
| memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, | |||
| src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ | |||
| #define MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ | |||
| #include "nnacl/conv_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/depth_to_space.h" | |||
| #include <string.h> | |||
| #include "nnacl/base/depth_to_space_base.h" | |||
| void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceParameter *param) { | |||
| int32_t block_size = param->block_size_; | |||
| @@ -15,6 +15,8 @@ | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_ | |||
| #define MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_ | |||
| #include <string.h> | |||
| #include "nnacl/depth_to_space_parameter.h" | |||
| #ifdef __cplusplus | |||
| @@ -18,7 +18,7 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "mindspore/lite/nnacl/int8/fixed_point.h" | |||
| typedef struct ClipParameter { | |||
| OpParameter op_parameter_; | |||
| @@ -17,11 +17,10 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_COMMON_FUNC_H_ | |||
| #define MINDSPORE_LITE_NNACL_COMMON_FUNC_H_ | |||
| #include <stdint.h> | |||
| #include <stdio.h> | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/nnacl_common.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -63,14 +62,6 @@ static inline int GetStride(int *strides, const int *shape, int length) { | |||
| return stride; | |||
| } | |||
| inline void ComputeStrides(const int *shape, int *strides, const int ndim) { | |||
| int stride = 1; | |||
| for (int i = ndim - 1; i >= 0; i--) { | |||
| strides[i] = stride; | |||
| stride *= shape[i]; | |||
| } | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); | |||
| void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_CONCAT_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| typedef struct ConcatParameter { | |||
| OpParameter op_parameter_; | |||
| @@ -21,7 +21,7 @@ | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| typedef struct ConvParameter { | |||
| OpParameter op_parameter_; | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_CROP_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| #define CROP_OFFSET_MAX_SIZE 4 | |||
| @@ -21,7 +21,7 @@ | |||
| #endif | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "mindspore/lite/nnacl/int8/fixed_point.h" | |||
| typedef struct ActivationParameter { | |||
| OpParameter op_parameter_; | |||
| @@ -20,7 +20,7 @@ | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/arithmetic.h" | |||
| #include "nnacl/base/arithmetic_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #ifdef __cplusplus | |||
| @@ -16,7 +16,6 @@ | |||
| #include "nnacl/fp16/pack_fp16.h" | |||
| #include <string.h> | |||
| #include <stdlib.h> | |||
| void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, | |||
| int block_index) { | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_ | |||
| #include <math.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| @@ -18,7 +18,7 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "mindspore/lite/nnacl/int8/fixed_point.h" | |||
| typedef struct ActivationParameter { | |||
| OpParameter op_parameter_; | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/arg_min_max_fp32.h" | |||
| #include <stdlib.h> | |||
| #include <float.h> | |||
| int ArgCompareAscFp32(const void *a, const void *b) { | |||
| @@ -20,7 +20,7 @@ | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/arithmetic.h" | |||
| #include "nnacl/base/arithmetic_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #ifdef __cplusplus | |||
| @@ -17,8 +17,6 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ | |||
| #include <stdint.h> | |||
| #include <stdio.h> | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| @@ -0,0 +1,479 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/pack_fp32.h" | |||
/* Converts a KHW-ordered weight blob to HWK order. With the K axis treated
 * as "channel" this is exactly NCHW -> NHWC with batch == 1, so the general
 * transpose is reused.
 * Fix vs. original: `return <void expression>;` in a void function is a C
 * constraint violation (only valid in C++); call the helper, then return. */
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
/* Transposes the two leading spatial axes of an HWC tensor, producing WHC.
 * Each length-`channel` vector is moved as one contiguous chunk; src and
 * dst must not overlap. */
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  const size_t pixel_bytes = channel * sizeof(float);
  for (int w = 0; w < width; ++w) {
    for (int h = 0; h < height; ++h) {
      const float *from = src + (h * width + w) * channel;
      float *to = dst + (w * height + h) * channel;
      memcpy(to, from, pixel_bytes);
    }
  }
}
/* im2col packing for fp32 convolution, NHWC input.
 * Gathers the receptive-field patches of `real_cal_num` consecutive output
 * pixels, starting at flat output index `block_index`, into `packed_input`
 * (one kernel_h * kernel_w * in_channel patch per output pixel).
 * Kernel taps that fall into the padded border are skipped, so the caller
 * presumably pre-zeroes `packed_input` -- TODO confirm with callers. */
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    // top-left corner of this output pixel's receptive field; may be
    // negative when it starts inside the padded border
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    // clamp the kernel window [kh_s,kh_e) x [kw_s,kw_e) to the taps that
    // actually overlap the real input
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      // dense kernel: one kernel row is contiguous in NHWC, copy it whole
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      // dilated kernel: taps are strided, copy one channel vector at a time
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
/* Re-lays an NHWC fp32 tensor as NC4HW4: channels are split into groups of
 * C4NUM, each group stored as a plane-major slab with the 4 channel lanes
 * innermost. Full groups are copied 4 floats at a time (NEON when
 * available); the last, possibly partial, group is copied lane by lane.
 * Padding lanes of the destination are left untouched. */
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1;  // number of guaranteed-full channel groups
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      // tail group: only the `res_c` real channels are written
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}
| void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| for (int b = 0; b < batch; b++) { | |||
| int src_offset = b * plane * channel; | |||
| int dst_offset = b * plane * c4 * C4NUM; | |||
| for (int c = 0; c < channel; c++) { | |||
| int c4_block_num = c / C4NUM; | |||
| int c4_block_rem = c % C4NUM; | |||
| int src_c_offset = src_offset + c * plane; | |||
| int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; | |||
| for (int k = 0; k < plane; k++) { | |||
| int src_kernel_offset = src_c_offset + k; | |||
| int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; | |||
| ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| int c4_channel = c4 * C4NUM; | |||
| int nhwc4_batch_unit_offset = c4 * C4NUM * plane; | |||
| int ic_remainder_ = channel % C4NUM; | |||
| if (ic_remainder_ != 0) { | |||
| int nhwc4_batch_offset = 0; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * channel * plane; | |||
| for (int i = 0; i < plane; i++) { | |||
| float *dst_per_plane = (float *)dst + nhwc4_batch_offset + i * c4_channel; | |||
| memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); | |||
| for (int j = channel; j < c4_channel; ++j) { | |||
| dst_per_plane[j] = 0; | |||
| } | |||
| } | |||
| nhwc4_batch_offset += nhwc4_batch_unit_offset; | |||
| } | |||
| } else { | |||
| size_t ori_input_size = batch * plane * channel * sizeof(float); | |||
| memcpy((float *)dst, (float *)src, ori_input_size); | |||
| } | |||
| } | |||
| void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int c8 = UP_DIV(channel, C8NUM); | |||
| int c8_channel = c8 * C8NUM; | |||
| int nhwc8_batch_unit_offset = c8 * C8NUM * plane; | |||
| int ic_remainder_ = channel % C8NUM; | |||
| if (ic_remainder_ != 0) { | |||
| int nhwc8_batch_offset = 0; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * channel * plane; | |||
| for (int i = 0; i < plane; i++) { | |||
| float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel; | |||
| memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); | |||
| for (int j = channel; j < c8_channel; ++j) { | |||
| dst_per_plane[j] = 0; | |||
| } | |||
| } | |||
| nhwc8_batch_offset += nhwc8_batch_unit_offset; | |||
| } | |||
| } else { | |||
| size_t ori_input_size = batch * plane * channel * sizeof(float); | |||
| memcpy((float *)dst, (float *)src, ori_input_size); | |||
| } | |||
| } | |||
| void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| int ic_remainder_ = channel % C4NUM; | |||
| if (ic_remainder_ != 0) { | |||
| int nhwc_batch_unit_offset = channel * plane; | |||
| for (int b = 0; b < batch; b++) { | |||
| int batch_offset = b * c4 * C4NUM * plane; | |||
| for (int i = 0; i < plane; i++) { | |||
| memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM, | |||
| channel * sizeof(float)); | |||
| } | |||
| } | |||
| } else { | |||
| size_t ori_input_size = batch * plane * channel * sizeof(float); | |||
| memcpy((float *)dst, (float *)src, ori_input_size); | |||
| } | |||
| } | |||
/* Re-lays an NC4HW4 fp32 tensor as NHWC4, element by element.
 * NOTE(review): the per-batch destination stride is `plane * channel`, while
 * the per-pixel stride inside a batch is `c4 * C4NUM` -- these disagree when
 * channel is not a multiple of C4NUM and batch > 1; presumably callers only
 * use batch == 1 or aligned channels. Verify against callers. */
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM;  // 4-lane group index
      int c4_block_res = c % C4NUM;  // lane inside the group
      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
      for (int k = 0; k < plane; k++) {
        int src_kernel_offset = src_c_offset + k * C4NUM;
        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
      }
    }
  }
}
/* Unpacks an NC4HW4 fp32 tensor back to plain NHWC. Full 4-lane channel
 * groups are moved four floats at a time (NEON when available); the last,
 * possibly partial, group is copied lane by lane. */
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      // full 4-lane groups
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
        vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // res part
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}
| void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { | |||
| for (int n = 0; n < batch; n++) { | |||
| for (int hw = 0; hw < plane; hw++) { | |||
| for (int c = 0; c < channel; c++) { | |||
| int c8div = c / C8NUM; | |||
| int c8mod = c % C8NUM; | |||
| int src_index = n * plane * channel + hw * channel + c; | |||
| int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; | |||
| ((float *)dst)[dst_index] = ((float *)src)[src_index]; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { | |||
| int c4 = UP_DIV(channel, C4NUM); | |||
| for (int c = 0; c < c4; c++) { | |||
| int dst_off_c = c * C4NUM * height * width; | |||
| for (int i = 0; i < C4NUM; i++) { | |||
| int src_off_c = (c * C4NUM + i) * height * width; | |||
| for (int kh = 0; kh < height; kh++) { | |||
| int src_off_kh = src_off_c + kh * width; | |||
| for (int kw = 0; kw < width; kw++) { | |||
| int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; | |||
| ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { | |||
| int c8 = UP_DIV(channel, C8NUM); | |||
| for (int c = 0; c < c8; c++) { | |||
| int dst_off_c = c * C8NUM * height * width; | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| int src_off_c = (c * C8NUM + i) * height * width; | |||
| for (int kh = 0; kh < height; kh++) { | |||
| int src_off_kh = src_off_c + kh * width; | |||
| for (int kw = 0; kw < width; kw++) { | |||
| int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; | |||
| ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
#ifndef ENABLE_SSE
/* Transposes NHWC -> NCHW per batch (an SSE build supplies its own version,
 * hence the guard). The plane/channel matrix is processed in 8x8 tiles:
 * full tiles use hand-written NEON transposes on ARM64/ARM32 and a scalar
 * 8x8 loop elsewhere; ragged edges fall through to scalar loops. */
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
  int hw8 = plane / C8NUM * C8NUM;   // plane rounded down to full 8-tiles
  int c8 = channel / C8NUM * C8NUM;  // channel rounded down to full 8-tiles
  int batch = plane * channel;       // elements per batch
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = 0;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
        // 8x8 fp32 transpose: 16 strided 4-lane loads, zip/trn shuffles,
        // 16 strided stores. Row strides are in bytes.
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov x10, %[src_ptr]\n"
          "mov x11, %[dst_ptr]\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "trn1 v16.2d, v8.2d, v10.2d\n"
          "trn2 v18.2d, v8.2d, v10.2d\n"
          "trn1 v20.2d, v9.2d, v11.2d\n"
          "trn2 v22.2d, v9.2d, v11.2d\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "trn1 v24.2d, v12.2d, v14.2d\n"
          "trn2 v26.2d, v12.2d, v14.2d\n"
          "trn1 v28.2d, v13.2d, v15.2d\n"
          "trn2 v30.2d, v13.2d, v15.2d\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "trn1 v17.2d, v8.2d, v10.2d\n"
          "trn2 v19.2d, v8.2d, v10.2d\n"
          "trn1 v21.2d, v9.2d, v11.2d\n"
          "trn2 v23.2d, v9.2d, v11.2d\n"
          "trn1 v25.2d, v12.2d, v14.2d\n"
          "trn2 v27.2d, v12.2d, v14.2d\n"
          "trn1 v29.2d, v13.2d, v15.2d\n"
          "trn2 v31.2d, v13.2d, v15.2d\n"
          "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
          "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
          "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
          "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
          "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
          "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
          "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
          "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
            "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
            "v30", "v31");
#elif ENABLE_ARM32
        // 8x8 fp32 transpose via vtrn/vswp; the two output half-rows are
        // written through two pointers 16 bytes apart.
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov r10, %[src_ptr]\n"
          "mov r12, %[dst_ptr]\n"
          "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
          "vld1.32 {q2, q3}, [r10], %[srcStride]\n"
          "vtrn.32 d0, d4\n"
          "vtrn.32 d1, d5\n"
          "vtrn.32 d2, d6\n"
          "vtrn.32 d3, d7\n"
          "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
          "vld1.32 {q6, q7}, [r10], %[srcStride]\n"
          "vtrn.32 d8, d12\n"
          "vtrn.32 d9, d13\n"
          "vtrn.32 d10, d14\n"
          "vtrn.32 d11, d15\n"
          "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
          "vld1.32 {q10, q11}, [r10], %[srcStride]\n"
          "vswp d1, d8\n"
          "vswp d3, d10\n"
          "vswp d5, d12\n"
          "vswp d7, d14\n"
          "vtrn.32 d16, d20\n"
          "vtrn.32 d17, d21\n"
          "vtrn.32 d18, d22\n"
          "vtrn.32 d19, d23\n"
          "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
          "vld1.32 {q14, q15}, [r10], %[srcStride]\n"
          "vtrn.32 d24, d28\n"
          "vtrn.32 d25, d29\n"
          "vtrn.32 d26, d30\n"
          "vtrn.32 d27, d31\n"
          "vswp d17, d24\n"
          "vswp d19, d26\n"
          "vswp d21, d28\n"
          "vswp d23, d30\n"
          "add r10, r12, #16\n"
          "vst1.32 {q0}, [r12], %[dstStride]\n"
          "vst1.32 {q8}, [r10], %[dstStride]\n"
          "vst1.32 {q2}, [r12], %[dstStride]\n"
          "vst1.32 {q10}, [r10], %[dstStride]\n"
          "vst1.32 {q4}, [r12], %[dstStride]\n"
          "vst1.32 {q12}, [r10], %[dstStride]\n"
          "vst1.32 {q6}, [r12], %[dstStride]\n"
          "vst1.32 {q14}, [r10], %[dstStride]\n"
          "vst1.32 {q1}, [r12], %[dstStride]\n"
          "vst1.32 {q9}, [r10], %[dstStride]\n"
          "vst1.32 {q3}, [r12], %[dstStride]\n"
          "vst1.32 {q11}, [r10], %[dstStride]\n"
          "vst1.32 {q5}, [r12], %[dstStride]\n"
          "vst1.32 {q13}, [r10], %[dstStride]\n"
          "vst1.32 {q7}, [r12], %[dstStride]\n"
          "vst1.32 {q15}, [r10], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
            "q15");
#else
        // portable scalar 8x8 tile transpose
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      // ragged channel tail of a full 8-row strip
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    // ragged plane tail: one output column per remaining row
    // NOTE(review): `size_t i < channel` mixes signed/unsigned; fine while
    // channel is non-negative
    for (; hw < plane; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
  return;
}
#endif
/* NCHW -> NHWC is the same transpose as NHWC -> NCHW with the plane and
 * channel extents swapped, so the general kernel is reused.
 * Fix vs. original: `return <void expression>;` in a void function is a C
 * constraint violation (only valid in C++); call the helper, then return. */
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_PACK_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_PACK_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/conv_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); | |||
| void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); | |||
| void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); | |||
| void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); | |||
| void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, | |||
| int block_index); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_PAD_H_ | |||
| @@ -22,7 +22,7 @@ | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/pooling_parameter.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -18,7 +18,7 @@ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { | |||
| int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); | |||
| @@ -17,7 +17,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_ARG_MIN_MAX_INT8_H_ | |||
| #include "nnacl/arg_min_max_parameter.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -17,8 +17,8 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_ARITHMETIC_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/arithmetic.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/base/arithmetic_base.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -21,7 +21,7 @@ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #endif | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { | |||
| float in_scale = para.in_args_.scale_; | |||
| @@ -22,7 +22,7 @@ | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -16,7 +16,7 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, | |||
| size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, | |||
| @@ -17,8 +17,6 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_ | |||
| #include <stdint.h> | |||
| #include <stdio.h> | |||
| #include <string.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| @@ -29,9 +27,6 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); | |||
| void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, | |||
| int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, | |||
| int32_t maxi); | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/int8/conv1x1_int8.h" | |||
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) { | |||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||
| matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, | |||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, | |||
| filter_zp); | |||
| return; | |||
| } | |||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) { | |||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||
| MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, | |||
| conv_param->output_channel_, is_per_oc, filter_zp); | |||
| return; | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"
#ifdef __cplusplus
extern "C" {
#endif
// 1x1 int8 convolution via the generic int8 matmul; operands are packed with
// depth padded to a multiple of 16 (deep16). filter_zp carries the filter
// zero-point(s) used for requantization.
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
// Same operation dispatched through a caller-supplied optimized matmul kernel
// (matmul_func); depth is padded to a multiple of 4 (deep4).
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
| @@ -0,0 +1,900 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/int8/conv3x3_int8.h" | |||
// Input transform for one 4x4 tile of the Winograd-style 3x3 int8 convolution.
// tmp_data holds a 4x4 spatial tile with C8NUM (8) interleaved channels
// (16 groups of 8 int16 values). The routine subtracts the input zero-point
// and applies the transform (t = row pass, m = column pass), then scatters the
// 16 transformed groups into trans_input_data at stride `step`.
// The NEON and scalar branches compute exactly the same values; NEON handles
// all 8 channels per vector, the scalar path loops channel-by-channel.
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
#ifdef ENABLE_ARM
  // Load the 4x4 tile (8 channels per element) and remove the zero-point.
  int16x8_t zp = vdupq_n_s16(input_zp);
  int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp);
  int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp);
  int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp);
  int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp);
  int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp);
  int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp);
  int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp);
  int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp);
  int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp);
  int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp);
  int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp);
  int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp);
  int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp);
  int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp);
  int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp);
  int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp);
  // First (row) pass of the transform.
  int16x8_t t00 = vsubq_s16(d00, d20);
  int16x8_t t01 = vsubq_s16(d01, d21);
  int16x8_t t02 = vsubq_s16(d02, d22);
  int16x8_t t03 = vsubq_s16(d03, d23);
  int16x8_t t10 = vaddq_s16(d10, d20);
  int16x8_t t11 = vaddq_s16(d11, d21);
  int16x8_t t12 = vaddq_s16(d12, d22);
  int16x8_t t13 = vaddq_s16(d13, d23);
  int16x8_t t20 = vsubq_s16(d20, d10);
  int16x8_t t21 = vsubq_s16(d21, d11);
  int16x8_t t22 = vsubq_s16(d22, d12);
  int16x8_t t23 = vsubq_s16(d23, d13);
  int16x8_t t30 = vsubq_s16(d10, d30);
  int16x8_t t31 = vsubq_s16(d11, d31);
  int16x8_t t32 = vsubq_s16(d12, d32);
  int16x8_t t33 = vsubq_s16(d13, d33);
  // Second (column) pass.
  int16x8_t m00 = vsubq_s16(t00, t02);
  int16x8_t m01 = vaddq_s16(t01, t02);
  int16x8_t m02 = vsubq_s16(t02, t01);
  int16x8_t m03 = vsubq_s16(t01, t03);
  int16x8_t m10 = vsubq_s16(t10, t12);
  int16x8_t m11 = vaddq_s16(t11, t12);
  int16x8_t m12 = vsubq_s16(t12, t11);
  int16x8_t m13 = vsubq_s16(t11, t13);
  int16x8_t m20 = vsubq_s16(t20, t22);
  int16x8_t m21 = vaddq_s16(t21, t22);
  int16x8_t m22 = vsubq_s16(t22, t21);
  int16x8_t m23 = vsubq_s16(t21, t23);
  int16x8_t m30 = vsubq_s16(t30, t32);
  int16x8_t m31 = vaddq_s16(t31, t32);
  int16x8_t m32 = vsubq_s16(t32, t31);
  int16x8_t m33 = vsubq_s16(t31, t33);
  // Scatter the 16 transformed vectors at stride `step`.
  vst1q_s16(trans_input_data, m00);
  vst1q_s16(trans_input_data + step, m01);
  vst1q_s16(trans_input_data + 2 * step, m02);
  vst1q_s16(trans_input_data + 3 * step, m03);
  vst1q_s16(trans_input_data + 4 * step, m10);
  vst1q_s16(trans_input_data + 5 * step, m11);
  vst1q_s16(trans_input_data + 6 * step, m12);
  vst1q_s16(trans_input_data + 7 * step, m13);
  vst1q_s16(trans_input_data + 8 * step, m20);
  vst1q_s16(trans_input_data + 9 * step, m21);
  vst1q_s16(trans_input_data + 10 * step, m22);
  vst1q_s16(trans_input_data + 11 * step, m23);
  vst1q_s16(trans_input_data + 12 * step, m30);
  vst1q_s16(trans_input_data + 13 * step, m31);
  vst1q_s16(trans_input_data + 14 * step, m32);
  vst1q_s16(trans_input_data + 15 * step, m33);
#else
  // Scalar fallback: identical math, one channel of the 8-channel group at a
  // time (i indexes the channel within the C8NUM interleave).
  for (int i = 0; i < C8NUM; i++) {
    int16_t *local_ptr = tmp_data + i;
    int16_t d00 = local_ptr[0] - input_zp;
    int16_t d01 = (local_ptr + C8NUM)[0] - input_zp;
    int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp;
    int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp;
    int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp;
    int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp;
    int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp;
    int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp;
    int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp;
    int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp;
    int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp;
    int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp;
    int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp;
    int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp;
    int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp;
    int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp;
    // Row pass.
    int16_t t00 = d00 - d20;
    int16_t t01 = d01 - d21;
    int16_t t02 = d02 - d22;
    int16_t t03 = d03 - d23;
    int16_t t10 = d10 + d20;
    int16_t t11 = d11 + d21;
    int16_t t12 = d12 + d22;
    int16_t t13 = d13 + d23;
    int16_t t20 = d20 - d10;
    int16_t t21 = d21 - d11;
    int16_t t22 = d22 - d12;
    int16_t t23 = d23 - d13;
    int16_t t30 = d10 - d30;
    int16_t t31 = d11 - d31;
    int16_t t32 = d12 - d32;
    int16_t t33 = d13 - d33;
    // Column pass.
    int16_t m00 = t00 - t02;
    int16_t m01 = t01 + t02;
    int16_t m02 = t02 - t01;
    int16_t m03 = t01 - t03;
    int16_t m10 = t10 - t12;
    int16_t m11 = t11 + t12;
    int16_t m12 = t12 - t11;
    int16_t m13 = t11 - t13;
    int16_t m20 = t20 - t22;
    int16_t m21 = t21 + t22;
    int16_t m22 = t22 - t21;
    int16_t m23 = t21 - t23;
    int16_t m30 = t30 - t32;
    int16_t m31 = t31 + t32;
    int16_t m32 = t32 - t31;
    int16_t m33 = t31 - t33;
    // Scatter this channel's 16 results at stride `step`.
    (trans_input_data + i)[0] = m00;
    (trans_input_data + i + step)[0] = m01;
    (trans_input_data + i + 2 * step)[0] = m02;
    (trans_input_data + i + 3 * step)[0] = m03;
    (trans_input_data + i + 4 * step)[0] = m10;
    (trans_input_data + i + 5 * step)[0] = m11;
    (trans_input_data + i + 6 * step)[0] = m12;
    (trans_input_data + i + 7 * step)[0] = m13;
    (trans_input_data + i + 8 * step)[0] = m20;
    (trans_input_data + i + 9 * step)[0] = m21;
    (trans_input_data + i + 10 * step)[0] = m22;
    (trans_input_data + i + 11 * step)[0] = m23;
    (trans_input_data + i + 12 * step)[0] = m30;
    (trans_input_data + i + 13 * step)[0] = m31;
    (trans_input_data + i + 14 * step)[0] = m32;
    (trans_input_data + i + 15 * step)[0] = m33;
  }
#endif
}
// Filter transform for the Winograd-style 3x3 int8 convolution: expands each
// 3x3 kernel into a 4x4 (input_unit x input_unit) transformed kernel.
// weight_data is laid out per output channel with iC8 groups of C8NUM input
// channels and kernel_plane (9) spatial taps; trans_weight is tiled by C4NUM
// output channels. The two-pass scheme mirrors the input transform: dst* is
// the first (row) pass, m* the second (column) pass; dst_step separates the
// 16 transformed positions in the destination.
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
                                int kernel_plane) {
  const int input_unit = 4;
  int dst_step = iC8 * C8NUM * C4NUM;
  for (int o = 0; o < output_channel; o++) {
    // Destination is tiled in blocks of C4NUM output channels.
    int oc4_block_num = o / C4NUM;
    int oc4_block_rem = o % C4NUM;
    int src_oc_offset = o * iC8 * C8NUM * kernel_plane;
    int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem;
    for (int i = 0; i < iC8; i++) {
      const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM;
      int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM;
#ifdef ENABLE_ARM
      // Load the 3x3 kernel taps, 8 input channels per vector (g[row][col]).
      int16x8_t g00 = vld1q_s16(src_ic8_ptr);
      int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8);
      int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8);
      int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8);
      int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8);
      int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8);
      int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8);
      int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8);
      int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8);
      // Row pass: 3 rows -> 4 rows (scaled by 2 to stay in integer domain).
      int16x8_t dst00 = vmulq_n_s16(g00, 2);
      int16x8_t dst01 = vmulq_n_s16(g01, 2);
      int16x8_t dst02 = vmulq_n_s16(g02, 2);
      int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20);
      int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21);
      int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22);
      int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20);
      int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21);
      int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22);
      int16x8_t dst30 = vmulq_n_s16(g20, 2);
      int16x8_t dst31 = vmulq_n_s16(g21, 2);
      int16x8_t dst32 = vmulq_n_s16(g22, 2);
      // Column pass: 3 columns -> 4 columns.
      int16x8_t m00 = vmulq_n_s16(dst00, 2);
      int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02);
      int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02);
      int16x8_t m03 = vmulq_n_s16(dst02, 2);
      int16x8_t m10 = vmulq_n_s16(dst10, 2);
      int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12);
      int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12);
      int16x8_t m13 = vmulq_n_s16(dst12, 2);
      int16x8_t m20 = vmulq_n_s16(dst20, 2);
      int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22);
      int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22);
      int16x8_t m23 = vmulq_n_s16(dst22, 2);
      int16x8_t m30 = vmulq_n_s16(dst30, 2);
      int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32);
      int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32);
      int16x8_t m33 = vmulq_n_s16(dst32, 2);
      // Lane-by-lane scatter: lane j of each vector lands at dst_ic8_ptr[j*4],
      // and each of the 16 transformed positions is dst_step apart.
      dst_ic8_ptr[0] = m00[0];
      dst_ic8_ptr[4] = m00[1];
      dst_ic8_ptr[8] = m00[2];
      dst_ic8_ptr[12] = m00[3];
      dst_ic8_ptr[16] = m00[4];
      dst_ic8_ptr[20] = m00[5];
      dst_ic8_ptr[24] = m00[6];
      dst_ic8_ptr[28] = m00[7];
      dst_ic8_ptr[0 + dst_step] = m01[0];
      dst_ic8_ptr[4 + dst_step] = m01[1];
      dst_ic8_ptr[8 + dst_step] = m01[2];
      dst_ic8_ptr[12 + dst_step] = m01[3];
      dst_ic8_ptr[16 + dst_step] = m01[4];
      dst_ic8_ptr[20 + dst_step] = m01[5];
      dst_ic8_ptr[24 + dst_step] = m01[6];
      dst_ic8_ptr[28 + dst_step] = m01[7];
      dst_ic8_ptr[0 + 2 * dst_step] = m02[0];
      dst_ic8_ptr[4 + 2 * dst_step] = m02[1];
      dst_ic8_ptr[8 + 2 * dst_step] = m02[2];
      dst_ic8_ptr[12 + 2 * dst_step] = m02[3];
      dst_ic8_ptr[16 + 2 * dst_step] = m02[4];
      dst_ic8_ptr[20 + 2 * dst_step] = m02[5];
      dst_ic8_ptr[24 + 2 * dst_step] = m02[6];
      dst_ic8_ptr[28 + 2 * dst_step] = m02[7];
      dst_ic8_ptr[0 + 3 * dst_step] = m03[0];
      dst_ic8_ptr[4 + 3 * dst_step] = m03[1];
      dst_ic8_ptr[8 + 3 * dst_step] = m03[2];
      dst_ic8_ptr[12 + 3 * dst_step] = m03[3];
      dst_ic8_ptr[16 + 3 * dst_step] = m03[4];
      dst_ic8_ptr[20 + 3 * dst_step] = m03[5];
      dst_ic8_ptr[24 + 3 * dst_step] = m03[6];
      dst_ic8_ptr[28 + 3 * dst_step] = m03[7];
      dst_ic8_ptr[0 + 4 * dst_step] = m10[0];
      dst_ic8_ptr[4 + 4 * dst_step] = m10[1];
      dst_ic8_ptr[8 + 4 * dst_step] = m10[2];
      dst_ic8_ptr[12 + 4 * dst_step] = m10[3];
      dst_ic8_ptr[16 + 4 * dst_step] = m10[4];
      dst_ic8_ptr[20 + 4 * dst_step] = m10[5];
      dst_ic8_ptr[24 + 4 * dst_step] = m10[6];
      dst_ic8_ptr[28 + 4 * dst_step] = m10[7];
      dst_ic8_ptr[0 + 5 * dst_step] = m11[0];
      dst_ic8_ptr[4 + 5 * dst_step] = m11[1];
      dst_ic8_ptr[8 + 5 * dst_step] = m11[2];
      dst_ic8_ptr[12 + 5 * dst_step] = m11[3];
      dst_ic8_ptr[16 + 5 * dst_step] = m11[4];
      dst_ic8_ptr[20 + 5 * dst_step] = m11[5];
      dst_ic8_ptr[24 + 5 * dst_step] = m11[6];
      dst_ic8_ptr[28 + 5 * dst_step] = m11[7];
      dst_ic8_ptr[0 + 6 * dst_step] = m12[0];
      dst_ic8_ptr[4 + 6 * dst_step] = m12[1];
      dst_ic8_ptr[8 + 6 * dst_step] = m12[2];
      dst_ic8_ptr[12 + 6 * dst_step] = m12[3];
      dst_ic8_ptr[16 + 6 * dst_step] = m12[4];
      dst_ic8_ptr[20 + 6 * dst_step] = m12[5];
      dst_ic8_ptr[24 + 6 * dst_step] = m12[6];
      dst_ic8_ptr[28 + 6 * dst_step] = m12[7];
      dst_ic8_ptr[0 + 7 * dst_step] = m13[0];
      dst_ic8_ptr[4 + 7 * dst_step] = m13[1];
      dst_ic8_ptr[8 + 7 * dst_step] = m13[2];
      dst_ic8_ptr[12 + 7 * dst_step] = m13[3];
      dst_ic8_ptr[16 + 7 * dst_step] = m13[4];
      dst_ic8_ptr[20 + 7 * dst_step] = m13[5];
      dst_ic8_ptr[24 + 7 * dst_step] = m13[6];
      dst_ic8_ptr[28 + 7 * dst_step] = m13[7];
      dst_ic8_ptr[0 + 8 * dst_step] = m20[0];
      dst_ic8_ptr[4 + 8 * dst_step] = m20[1];
      dst_ic8_ptr[8 + 8 * dst_step] = m20[2];
      dst_ic8_ptr[12 + 8 * dst_step] = m20[3];
      dst_ic8_ptr[16 + 8 * dst_step] = m20[4];
      dst_ic8_ptr[20 + 8 * dst_step] = m20[5];
      dst_ic8_ptr[24 + 8 * dst_step] = m20[6];
      dst_ic8_ptr[28 + 8 * dst_step] = m20[7];
      dst_ic8_ptr[0 + 9 * dst_step] = m21[0];
      dst_ic8_ptr[4 + 9 * dst_step] = m21[1];
      dst_ic8_ptr[8 + 9 * dst_step] = m21[2];
      dst_ic8_ptr[12 + 9 * dst_step] = m21[3];
      dst_ic8_ptr[16 + 9 * dst_step] = m21[4];
      dst_ic8_ptr[20 + 9 * dst_step] = m21[5];
      dst_ic8_ptr[24 + 9 * dst_step] = m21[6];
      dst_ic8_ptr[28 + 9 * dst_step] = m21[7];
      dst_ic8_ptr[0 + 10 * dst_step] = m22[0];
      dst_ic8_ptr[4 + 10 * dst_step] = m22[1];
      dst_ic8_ptr[8 + 10 * dst_step] = m22[2];
      dst_ic8_ptr[12 + 10 * dst_step] = m22[3];
      dst_ic8_ptr[16 + 10 * dst_step] = m22[4];
      dst_ic8_ptr[20 + 10 * dst_step] = m22[5];
      dst_ic8_ptr[24 + 10 * dst_step] = m22[6];
      dst_ic8_ptr[28 + 10 * dst_step] = m22[7];
      dst_ic8_ptr[0 + 11 * dst_step] = m23[0];
      dst_ic8_ptr[4 + 11 * dst_step] = m23[1];
      dst_ic8_ptr[8 + 11 * dst_step] = m23[2];
      dst_ic8_ptr[12 + 11 * dst_step] = m23[3];
      dst_ic8_ptr[16 + 11 * dst_step] = m23[4];
      dst_ic8_ptr[20 + 11 * dst_step] = m23[5];
      dst_ic8_ptr[24 + 11 * dst_step] = m23[6];
      dst_ic8_ptr[28 + 11 * dst_step] = m23[7];
      dst_ic8_ptr[0 + 12 * dst_step] = m30[0];
      dst_ic8_ptr[4 + 12 * dst_step] = m30[1];
      dst_ic8_ptr[8 + 12 * dst_step] = m30[2];
      dst_ic8_ptr[12 + 12 * dst_step] = m30[3];
      dst_ic8_ptr[16 + 12 * dst_step] = m30[4];
      dst_ic8_ptr[20 + 12 * dst_step] = m30[5];
      dst_ic8_ptr[24 + 12 * dst_step] = m30[6];
      dst_ic8_ptr[28 + 12 * dst_step] = m30[7];
      dst_ic8_ptr[0 + 13 * dst_step] = m31[0];
      dst_ic8_ptr[4 + 13 * dst_step] = m31[1];
      dst_ic8_ptr[8 + 13 * dst_step] = m31[2];
      dst_ic8_ptr[12 + 13 * dst_step] = m31[3];
      dst_ic8_ptr[16 + 13 * dst_step] = m31[4];
      dst_ic8_ptr[20 + 13 * dst_step] = m31[5];
      dst_ic8_ptr[24 + 13 * dst_step] = m31[6];
      dst_ic8_ptr[28 + 13 * dst_step] = m31[7];
      dst_ic8_ptr[0 + 14 * dst_step] = m32[0];
      dst_ic8_ptr[4 + 14 * dst_step] = m32[1];
      dst_ic8_ptr[8 + 14 * dst_step] = m32[2];
      dst_ic8_ptr[12 + 14 * dst_step] = m32[3];
      dst_ic8_ptr[16 + 14 * dst_step] = m32[4];
      dst_ic8_ptr[20 + 14 * dst_step] = m32[5];
      dst_ic8_ptr[24 + 14 * dst_step] = m32[6];
      dst_ic8_ptr[28 + 14 * dst_step] = m32[7];
      dst_ic8_ptr[0 + 15 * dst_step] = m33[0];
      dst_ic8_ptr[4 + 15 * dst_step] = m33[1];
      dst_ic8_ptr[8 + 15 * dst_step] = m33[2];
      dst_ic8_ptr[12 + 15 * dst_step] = m33[3];
      dst_ic8_ptr[16 + 15 * dst_step] = m33[4];
      dst_ic8_ptr[20 + 15 * dst_step] = m33[5];
      dst_ic8_ptr[24 + 15 * dst_step] = m33[6];
      dst_ic8_ptr[28 + 15 * dst_step] = m33[7];
#else
      // Scalar fallback: same two-pass transform, one of the 8 interleaved
      // input channels (j) at a time. Tap offsets are in units of int16:
      // +8/+16 step columns, +24 steps a kernel row (3 taps * 8 channels).
      for (int j = 0; j < C8NUM; j++) {
        const int16_t *local_ptr = src_ic8_ptr + j;
        int16_t dst00 = local_ptr[0] * 2;
        int16_t dst01 = (local_ptr + 8)[0] * 2;
        int16_t dst02 = (local_ptr + 16)[0] * 2;
        int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0];
        int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0];
        int16_t dst30 = (local_ptr + 48)[0] * 2;
        int16_t dst31 = (local_ptr + 56)[0] * 2;
        int16_t dst32 = (local_ptr + 64)[0] * 2;
        // Column pass.
        int16_t m00 = dst00 * 2;
        int16_t m01 = dst00 + dst01 + dst02;
        int16_t m02 = dst00 - dst01 + dst02;
        int16_t m03 = dst02 * 2;
        int16_t m10 = dst10 * 2;
        int16_t m11 = dst10 + dst11 + dst12;
        int16_t m12 = dst10 - dst11 + dst12;
        int16_t m13 = dst12 * 2;
        int16_t m20 = dst20 * 2;
        int16_t m21 = dst20 + dst21 + dst22;
        int16_t m22 = dst20 - dst21 + dst22;
        int16_t m23 = dst22 * 2;
        int16_t m30 = dst30 * 2;
        int16_t m31 = dst30 + dst31 + dst32;
        int16_t m32 = dst30 - dst31 + dst32;
        int16_t m33 = dst32 * 2;
        // Scatter: channel j at offset j*4, positions dst_step apart.
        *(dst_ic8_ptr + j * 4) = m00;
        *(dst_ic8_ptr + j * 4 + dst_step) = m01;
        *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02;
        *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03;
        *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10;
        *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11;
        *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12;
        *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13;
        *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20;
        *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21;
        *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22;
        *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23;
        *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30;
        *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31;
        *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32;
        *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33;
      }
#endif
    }
  }
}
| void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, | |||
| bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { | |||
| int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; | |||
| int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; | |||
| int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; | |||
| int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | |||
| int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; | |||
| int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; | |||
| #ifdef ENABLE_ARM | |||
| int32x4_t bias_ptr = vld1q_s32(bias_data); | |||
| int32x4_t s00 = vld1q_s32(gemm_out); | |||
| int32x4_t s01 = vld1q_s32(gemm_out + 4); | |||
| int32x4_t s02 = vld1q_s32(gemm_out + 8); | |||
| int32x4_t s03 = vld1q_s32(gemm_out + 12); | |||
| int32x4_t s10 = vld1q_s32(gemm_out + 16); | |||
| int32x4_t s11 = vld1q_s32(gemm_out + 20); | |||
| int32x4_t s12 = vld1q_s32(gemm_out + 24); | |||
| int32x4_t s13 = vld1q_s32(gemm_out + 28); | |||
| int32x4_t s20 = vld1q_s32(gemm_out + 32); | |||
| int32x4_t s21 = vld1q_s32(gemm_out + 36); | |||
| int32x4_t s22 = vld1q_s32(gemm_out + 40); | |||
| int32x4_t s23 = vld1q_s32(gemm_out + 44); | |||
| int32x4_t s30 = vld1q_s32(gemm_out + 48); | |||
| int32x4_t s31 = vld1q_s32(gemm_out + 52); | |||
| int32x4_t s32 = vld1q_s32(gemm_out + 56); | |||
| int32x4_t s33 = vld1q_s32(gemm_out + 60); | |||
| int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); | |||
| int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); | |||
| int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); | |||
| int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); | |||
| int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); | |||
| int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); | |||
| int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); | |||
| int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); | |||
| int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); | |||
| int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); | |||
| int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); | |||
| int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); | |||
| int32x4_t out_multiplier; | |||
| int32x4_t ls; | |||
| int32x4_t rs; | |||
| if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | |||
| out_multiplier = vld1q_s32(quant_multiplier + oc_start); | |||
| ls = vld1q_s32(left_shift + oc_start); | |||
| rs = vld1q_s32(right_shift + oc_start); | |||
| } else { | |||
| out_multiplier = vdupq_n_s32(quant_multiplier[0]); | |||
| ls = vdupq_n_s32(left_shift[0]); | |||
| rs = vdupq_n_s32(right_shift[0]); | |||
| } | |||
| int32x4_t out_zp = vdupq_n_s32(output_zp); | |||
| int32x4_t output_min = vdupq_n_s32(out_min); | |||
| int32x4_t output_max = vdupq_n_s32(out_max); | |||
| d00 = vqshlq_s32(d00, ls); | |||
| d00 = vqrdmulhq_s32(d00, out_multiplier); | |||
| int32x4_t carry = vandq_s32(d00, rs); | |||
| carry = vshrq_n_s32(carry, 31); | |||
| d00 = vqaddq_s32(d00, carry); | |||
| d00 = vqrshlq_s32(d00, rs); | |||
| d00 = vaddq_s32(d00, out_zp); | |||
| d00 = vmaxq_s32(d00, output_min); | |||
| d00 = vminq_s32(d00, output_max); | |||
| d01 = vqshlq_s32(d01, ls); | |||
| d01 = vqrdmulhq_s32(d01, out_multiplier); | |||
| carry = vandq_s32(d01, rs); | |||
| carry = vshrq_n_s32(carry, 31); | |||
| d01 = vqaddq_s32(d01, carry); | |||
| d01 = vqrshlq_s32(d01, rs); | |||
| d01 = vaddq_s32(d01, out_zp); | |||
| d01 = vmaxq_s32(d01, output_min); | |||
| d01 = vminq_s32(d01, output_max); | |||
| d10 = vqshlq_s32(d10, ls); | |||
| d10 = vqrdmulhq_s32(d10, out_multiplier); | |||
| carry = vandq_s32(d10, rs); | |||
| carry = vshrq_n_s32(carry, 31); | |||
| d10 = vqaddq_s32(d10, carry); | |||
| d10 = vqrshlq_s32(d10, rs); | |||
| d10 = vaddq_s32(d10, out_zp); | |||
| d10 = vmaxq_s32(d10, output_min); | |||
| d10 = vminq_s32(d10, output_max); | |||
| d11 = vqshlq_s32(d11, ls); | |||
| d11 = vqrdmulhq_s32(d11, out_multiplier); | |||
| carry = vandq_s32(d11, rs); | |||
| carry = vshrq_n_s32(carry, 31); | |||
| d11 = vqaddq_s32(d11, carry); | |||
| d11 = vqrshlq_s32(d11, rs); | |||
| d11 = vaddq_s32(d11, out_zp); | |||
| d11 = vmaxq_s32(d11, output_min); | |||
| d11 = vminq_s32(d11, output_max); | |||
| (output_data)[0] = (int8_t)d00[0]; | |||
| (output_data + 1)[0] = (int8_t)d00[1]; | |||
| (output_data + 2)[0] = (int8_t)d00[2]; | |||
| (output_data + 3)[0] = (int8_t)d00[3]; | |||
| if (w_not_bound) { | |||
| *(output_data + 4) = (int8_t)d01[0]; | |||
| *(output_data + 5) = (int8_t)d01[1]; | |||
| *(output_data + 6) = (int8_t)d01[2]; | |||
| *(output_data + 7) = (int8_t)d01[3]; | |||
| } | |||
| if (h_not_bound) { | |||
| *(output_data + output_w * 4) = (int8_t)d10[0]; | |||
| *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; | |||
| *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; | |||
| *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; | |||
| if (w_not_bound) { | |||
| *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; | |||
| *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; | |||
| *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; | |||
| *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; | |||
| } | |||
| } | |||
| #else | |||
| if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { | |||
| for (int i = 0; i < C4NUM; i++) { | |||
| const int32_t *local_ptr = gemm_out + i; | |||
| const int32_t *bias_ptr = bias_data + i; | |||
| int32_t s00 = local_ptr[0]; | |||
| int32_t s01 = (local_ptr + 4)[0]; | |||
| int32_t s02 = (local_ptr + 8)[0]; | |||
| int32_t s03 = (local_ptr + 12)[0]; | |||
| int32_t s10 = (local_ptr + 16)[0]; | |||
| int32_t s11 = (local_ptr + 20)[0]; | |||
| int32_t s12 = (local_ptr + 24)[0]; | |||
| int32_t s13 = (local_ptr + 28)[0]; | |||
| int32_t s20 = (local_ptr + 32)[0]; | |||
| int32_t s21 = (local_ptr + 36)[0]; | |||
| int32_t s22 = (local_ptr + 40)[0]; | |||
| int32_t s23 = (local_ptr + 44)[0]; | |||
| int32_t s30 = (local_ptr + 48)[0]; | |||
| int32_t s31 = (local_ptr + 52)[0]; | |||
| int32_t s32 = (local_ptr + 56)[0]; | |||
| int32_t s33 = (local_ptr + 60)[0]; | |||
| int32_t t00 = (s00 + s10 + s20) / 2; | |||
| int32_t t01 = (s01 + s11 + s21) / 2; | |||
| int32_t t02 = (s02 + s12 + s22) / 2; | |||
| int32_t t03 = (s03 + s13 + s23) / 2; | |||
| int32_t t10 = (s10 - s20 - s30) / 2; | |||
| int32_t t11 = (s11 - s21 - s31) / 2; | |||
| int32_t t12 = (s12 - s22 - s32) / 2; | |||
| int32_t t13 = (s13 - s23 - s33) / 2; | |||
| int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; | |||
| int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; | |||
| int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; | |||
| int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; | |||
| int oc_index = oc_start + i; | |||
| d00 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), | |||
| -right_shift[oc_index]); | |||
| d00 += output_zp; | |||
| d00 = d00 > out_min ? d00 : out_min; | |||
| d00 = d00 < out_max ? d00 : out_max; | |||
| d01 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), | |||
| -right_shift[oc_index]); | |||
| d01 += output_zp; | |||
| d01 = d01 > out_min ? d01 : out_min; | |||
| d01 = d01 < out_max ? d01 : out_max; | |||
| d10 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), | |||
| -right_shift[oc_index]); | |||
| d10 += output_zp; | |||
| d10 = d10 > out_min ? d10 : out_min; | |||
| d10 = d10 < out_max ? d10 : out_max; | |||
| d11 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), | |||
| -right_shift[oc_index]); | |||
| d11 += output_zp; | |||
| d11 = d11 > out_min ? d11 : out_min; | |||
| d11 = d11 < out_max ? d11 : out_max; | |||
| (output_data + i)[0] = (int8_t)d00; | |||
| if (w_not_bound) { | |||
| (output_data + i + C4NUM)[0] = (int8_t)d01; | |||
| } | |||
| if (h_not_bound) { | |||
| (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; | |||
| if (w_not_bound) { | |||
| (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| for (int i = 0; i < C4NUM; i++) { | |||
| const int32_t *local_ptr = gemm_out + i; | |||
| const int32_t *bias_ptr = bias_data + i; | |||
| int32_t s00 = local_ptr[0]; | |||
| int32_t s01 = (local_ptr + 4)[0]; | |||
| int32_t s02 = (local_ptr + 8)[0]; | |||
| int32_t s03 = (local_ptr + 12)[0]; | |||
| int32_t s10 = (local_ptr + 16)[0]; | |||
| int32_t s11 = (local_ptr + 20)[0]; | |||
| int32_t s12 = (local_ptr + 24)[0]; | |||
| int32_t s13 = (local_ptr + 28)[0]; | |||
| int32_t s20 = (local_ptr + 32)[0]; | |||
| int32_t s21 = (local_ptr + 36)[0]; | |||
| int32_t s22 = (local_ptr + 40)[0]; | |||
| int32_t s23 = (local_ptr + 44)[0]; | |||
| int32_t s30 = (local_ptr + 48)[0]; | |||
| int32_t s31 = (local_ptr + 52)[0]; | |||
| int32_t s32 = (local_ptr + 56)[0]; | |||
| int32_t s33 = (local_ptr + 60)[0]; | |||
| int32_t t00 = (s00 + s10 + s20) / 2; | |||
| int32_t t01 = (s01 + s11 + s21) / 2; | |||
| int32_t t02 = (s02 + s12 + s22) / 2; | |||
| int32_t t03 = (s03 + s13 + s23) / 2; | |||
| int32_t t10 = (s10 - s20 - s30) / 2; | |||
| int32_t t11 = (s11 - s21 - s31) / 2; | |||
| int32_t t12 = (s12 - s22 - s32) / 2; | |||
| int32_t t13 = (s13 - s23 - s33) / 2; | |||
| int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; | |||
| int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; | |||
| int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; | |||
| int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; | |||
| d00 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), | |||
| -right_shift[0]); | |||
| d00 += output_zp; | |||
| d00 = d00 > out_min ? d00 : out_min; | |||
| d00 = d00 < out_max ? d00 : out_max; | |||
| d01 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), | |||
| -right_shift[0]); | |||
| d01 += output_zp; | |||
| d01 = d01 > out_min ? d01 : out_min; | |||
| d01 = d01 < out_max ? d01 : out_max; | |||
| d10 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), | |||
| -right_shift[0]); | |||
| d10 += output_zp; | |||
| d10 = d10 > out_min ? d10 : out_min; | |||
| d10 = d10 < out_max ? d10 : out_max; | |||
| d11 = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), | |||
| -right_shift[0]); | |||
| d11 += output_zp; | |||
| d11 = d11 > out_min ? d11 : out_min; | |||
| d11 = d11 < out_max ? d11 : out_max; | |||
| (output_data + i)[0] = (int8_t)d00; | |||
| if (w_not_bound) { | |||
| (output_data + i + C4NUM)[0] = (int8_t)d01; | |||
| } | |||
| if (h_not_bound) { | |||
| (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; | |||
| if (w_not_bound) { | |||
| (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| int output_channel = conv_param->output_channel_; | |||
| int output_w = conv_param->output_w_; | |||
| int output_h = conv_param->output_h_; | |||
| const int oc4 = UP_DIV(output_channel, C4NUM); | |||
| const int input_unit = 4; | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int out_w_index = (start_index + i) % out_w_block; | |||
| int out_h_index = (start_index + i) / out_w_block; | |||
| int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; | |||
| int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); | |||
| for (int j = 0; j < oc4; j++) { | |||
| int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; | |||
| int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; | |||
| const int32_t *src_ptr = gemm_out + src_oc4_offset; | |||
| const int32_t *bias_ptr = bias_data + j * C4NUM; | |||
| int8_t *dst_ptr = out_data + dst_oc4_offset; | |||
| // output transform | |||
| int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; | |||
| bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; | |||
| bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; | |||
| Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, | |||
| conv_param); | |||
| } | |||
| } | |||
| } | |||
| void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| // input data format : nhwc | |||
| int input_channel = conv_param->input_channel_; | |||
| int input_width = conv_param->input_w_; | |||
| int input_height = conv_param->input_h_; | |||
| int pad_w = conv_param->pad_l_; | |||
| int pad_h = conv_param->pad_u_; | |||
| ConvQuantArg quant_arg = conv_param->conv_quant_arg_; | |||
| int input_zp = quant_arg.input_quant_args_[0].zp_; | |||
| const int ic8 = UP_DIV(input_channel, C8NUM); | |||
| const int input_unit = 4; | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { | |||
| int x_id = start_index + cal_id; | |||
| int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; | |||
| int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; | |||
| int real_x_start = origin_x > 0 ? 0 : -origin_x; | |||
| int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); | |||
| int real_y_start = origin_y > 0 ? 0 : -origin_y; | |||
| int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); | |||
| int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); | |||
| int dst_plane_offset = cal_id * C8NUM; | |||
| for (int ic = 0; ic < ic8; ic++) { | |||
| // copy data from origin input to tmp buffer | |||
| for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; | |||
| int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; | |||
| for (int j = real_y_start; j < real_y_end; j++) { | |||
| const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); | |||
| int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); | |||
| memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); | |||
| } | |||
| // input transform | |||
| int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; | |||
| size_t dst_step = ic8 * C8NUM * TILE_NUM; | |||
| int16_t *trans_input_ptr = trans_input + dst_ic8_offset; | |||
| Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); | |||
| } | |||
| } | |||
| } | |||
| void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { | |||
| int oc4 = UP_DIV(oc, C4NUM); | |||
| #ifdef ENABLE_ARM | |||
| IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); | |||
| #else | |||
| const int input_unit_square = 16; | |||
| for (int c = 0; c < oc4; c++) { | |||
| int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; | |||
| int dst_oc_offset = c * input_unit_square * C4NUM; | |||
| for (int n = 0; n < real_cal_num; n++) { | |||
| int src_tile_offset = n * C8NUM; | |||
| int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; | |||
| for (int i = 0; i < 4; i++) { | |||
| int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; | |||
| int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; | |||
| int dst_h_offset = dst_tile_offset + i * 4 * 4; | |||
| for (int m = 0; m < 4; m++) { | |||
| int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; | |||
| int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; | |||
| int dst_w_offset = dst_h_offset + m * C4NUM; | |||
| int32_t acc[4] = {0}; | |||
| for (int z = 0; z < 4; z++) { | |||
| int filter_offset = filter_w_offset + z; | |||
| for (int j = 0; j < ic8; j++) { | |||
| int filter_c8_offset = filter_offset + j * 4 * 8; | |||
| int src_c8_offset = src_w_offset + j * 8 * 8; | |||
| for (int k = 0; k < 8; k++) { | |||
| const int16_t *w_ptr = weight + filter_c8_offset + k * 4; | |||
| const int16_t *input_ptr = src + src_c8_offset + k; | |||
| acc[z] += w_ptr[0] * input_ptr[0]; | |||
| } | |||
| } | |||
| (dst + dst_w_offset + z)[0] = acc[z]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
// int8 convolution 3x3: per-batch, per-thread pipeline of
//   input transform (4x4 tiles) -> 16-unit gemm -> output transform.
// Each thread (task_id) handles every thread_num_-th block of TILE_NUM tiles,
// using its own slice of the tile/gemm scratch buffers.
// NOTE(review): output_data is never written inside this function; results
// are produced into tmp_out — presumably repacked by the caller. Confirm.
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
  int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
  int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
  int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
  int output_count = out_w_block * out_h_block;
  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
  // per-thread scratch strides: TILE_NUM tiles of 16 (4x4) transform units
  int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
  const int block_unit_buffer_offset = 16 * C8NUM;
  int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
  for (int batch = 0; batch < conv_param->input_batch_; batch++) {
    int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
    int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
    // round-robin the tile blocks across threads
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      int start_index = thread_id * TILE_NUM;
      // last block may contain fewer than TILE_NUM tiles
      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
      Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
                                block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
                                out_w_block, conv_param);
      Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
                      transed_weight, conv_param->output_channel_, ic8, real_cal_num);
      Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
    }
  }
}
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#include <string.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/winograd_transform.h"
#include "nnacl/int8/common_func_int8.h"
#ifdef __cplusplus
extern "C" {
#endif
// Transforms the raw 3x3 int8 filter (iC8-packed) into the tiled layout
// consumed by the Conv3x3Int8 gemm.
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
                                int kernel_plane);
// Runs the 3x3 int8 convolution for one thread (task_id); the caller supplies
// all scratch buffers (tile/gemm/output) sized for thread_num_ workers.
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
| @@ -16,7 +16,7 @@ | |||
| #include "nnacl/int8/conv_depthwise_int8.h" | |||
| #include <string.h> | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| /*conv depthwise int8 begin*/ | |||
| @@ -15,52 +15,6 @@ | |||
| */ | |||
| #include "nnacl/int8/conv_int8.h" | |||
| #include <string.h> | |||
| #include "nnacl/winograd_transform.h" | |||
| #include "nnacl/int8/common_func_int8.h" | |||
// 16-unit (4x4 transform points) gemm of the int8 3x3 convolution:
// multiplies transformed input tiles by transformed weights, accumulating
// int16 products into int32 results. On ARM the product is delegated to a
// hand-written assembly kernel.
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
  int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM
  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
  const int input_unit_square = 16;  // 4x4 transform units per tile
  for (int c = 0; c < oc4; c++) {
    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
    int dst_oc_offset = c * input_unit_square * C4NUM;
    for (int n = 0; n < real_cal_num; n++) {
      int src_tile_offset = n * C8NUM;
      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
      for (int i = 0; i < 4; i++) {
        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
        int dst_h_offset = dst_tile_offset + i * 4 * 4;
        for (int m = 0; m < 4; m++) {
          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
          int dst_w_offset = dst_h_offset + m * C4NUM;
          int32_t acc[4] = {0};
          // dot product over all input channels for each of the 4 output
          // channels (z) in this C4 block
          for (int z = 0; z < 4; z++) {
            int filter_offset = filter_w_offset + z;
            for (int j = 0; j < ic8; j++) {
              int filter_c8_offset = filter_offset + j * 4 * 8;
              int src_c8_offset = src_w_offset + j * 8 * 8;
              for (int k = 0; k < 8; k++) {
                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
                const int16_t *input_ptr = src + src_c8_offset + k;
                acc[z] += w_ptr[0] * input_ptr[0];
              }
            }
            (dst + dst_w_offset + z)[0] = acc[z];
          }
        }
      }
    }
  }
#endif
}
| void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, | |||
| const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, | |||
| @@ -141,717 +95,3 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in | |||
| } | |||
| } | |||
| } | |||
// Pre-processing for the per-output-channel (per-oc) quantized 1x1 int8 conv:
// packs the NHWC int8 input into a C4-blocked layout, 8 spatial rows at a
// time, and precomputes for every (row, output channel) the correction term
//   input_sum = (sum of that row's input values) * filter_zp[oc]
// which the matmul kernel later subtracts. Full 8-row tiles go through an
// ARM64 assembly fast path when available; leftover rows, channel tails and
// padding are handled by scalar code.
// NOTE(review): inputsum_stride appears to be counted in int32 elements
// (the asm converts it to bytes via `inputsum_stride * 4`) — confirm.
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                        size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int hw8 = UP_ROUND(plane_size, C8NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;
  size_t oc_8div = output_channel / C8NUM * C8NUM;
  size_t oc_8res = output_channel - oc_8div;
  size_t ic_4div = input_channel / C4NUM * C4NUM;
  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  int32_t *input_sum_r = input_sum;
  // full tiles of C8NUM spatial rows
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
    // Labels 1-6: pack 8 rows in C4 chunks (labels 3-5 handle 1/2/3 leftover
    // channels), while accumulating each row's running sum in v16/v17 via
    // saddlp pairwise adds. Labels 7-17: broadcast the eight row sums
    // (v0-v7) and multiply by filter_zp in C8 chunks (labels 9-15 handle the
    // 1..7-channel tail), storing the products into input_sum.
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"
      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"
      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "6: \n"
      "dup v0.4s, v16.s[0] \n"
      "dup v1.4s, v16.s[1] \n"
      "dup v2.4s, v16.s[2] \n"
      "dup v3.4s, v16.s[3] \n"
      "dup v4.4s, v17.s[0] \n"
      "dup v5.4s, v17.s[1] \n"
      "dup v6.4s, v17.s[2] \n"
      "dup v7.4s, v17.s[3] \n"
      "mov x4, #0 \n"
      "mov x10, %[filter_zp] \n"
      "mov x11, %[input_sum_oc] \n"
      "7: \n"
      "cmp x4, %[oc_8div] \n"
      "beq 8f \n"
      "add x4, x4, #8\n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.4s}, [x10], #16\n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "add x11, x11, %[input_sum_stride] \n"
      "b 7b \n"
      "8: \n"
      "cmp %[oc_8res], #0\n"
      "beq 17f \n"
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "cmp %[oc_8res], #1\n"
      "beq 9f \n"
      "cmp %[oc_8res], #2\n"
      "beq 10f \n"
      "cmp %[oc_8res], #3\n"
      "beq 11f \n"
      "cmp %[oc_8res], #4\n"
      "beq 12f \n"
      "cmp %[oc_8res], #5\n"
      "beq 13f \n"
      "cmp %[oc_8res], #6\n"
      "beq 14f \n"
      "cmp %[oc_8res], #7\n"
      "beq 15f \n"
      "9: \n"
      "ld1 {v16.s}[0], [x10] \n"
      "b 16f \n"
      "10: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "b 16f \n"
      "11: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v16.s}[2], [x10] \n"
      "b 16f \n"
      "12: \n"
      "ld1 {v16.4s}, [x10] \n"
      "b 16f \n"
      "13: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.s}[0], [x10] \n"
      "b 16f \n"
      "14: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "b 16f \n"
      "15: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v17.s}[2], [x10] \n"
      "b 16f \n"
      "16: \n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "17: \n"
      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
    // scalar fallback: pack C4 chunks for each of the 8 rows while summing
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    // channel tail (input_channel not a multiple of C4NUM)
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }
    // zero the padded channel slots up to ic4
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }
    // input_sum = row_sum * filter_zp, for full C8 blocks of output channels
    for (int oci = 0; oci < oc_8div; oci += C8NUM) {
      for (int ri = 0; ri < C8NUM; ri++) {
        input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
        input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
        input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
        input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
        input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
        input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
        input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
        input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
      }
      input_sum_oc += inputsum_stride;
    }
    // output-channel tail: real sums for oc_8res channels, zero padding after
    if (oc_8div != output_channel) {
      for (int oci = 0; oci < oc_8res; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
        }
      }
      for (int oci = oc_8res; oci < C8NUM; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = 0;
        }
      }
    } /* oc8 res done */
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
    input_sum_r += C8NUM * C8NUM;
  }
  // remaining spatial rows (plane_size not a multiple of C8NUM) — scalar only
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t *input_sum_oc = input_sum_r;
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      for (int oci = 0; oci < oc_8div; oci += C8NUM) {
        for (int curoi = 0; curoi < C8NUM; curoi++) {
          input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
        }
        input_sum_oc += inputsum_stride;
      }
      if (oc_8div != output_channel) {
        for (int oci = 0; oci < oc_8res; oci += 1) {
          input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
        }
        for (int oci = oc_8res; oci < C8NUM; oci += 1) {
          input_sum_oc[oci] = 0;
        }
      } /* oc8 res done */
      src_r += input_channel;
      pack_r += C4NUM;
      input_sum_r += C8NUM;
    }
    // zero the input_sum entries of the fully padded rows up to hw8
    for (int hwi = plane_size; hwi < hw8; hwi++) {
      for (int oc = 0; oc < oc8; oc++) {
        int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
        input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
      }
    }
  }
  return;
}
/* Pack the input of an int8 1x1 convolution for the per-tensor (per-layer)
 * quantization path, and precompute the per-pixel input-zero-point sums.
 *
 * Layout produced (matches the C8 row-tiling the int8 matmul consumes):
 * rows (spatial pixels) are processed in blocks of C8NUM; inside a block the
 * channels are packed in groups of C4NUM, interleaved across the 8 rows
 * (row-major within a C4 group: pack[g*C4NUM*C8NUM + row*C4NUM + c]).
 * input_sum[hwi] receives (sum of that pixel's input channels) * filter_zp,
 * which the matmul later subtracts as part of zero-point correction.
 *
 * src_input:     NHWC int8 input, plane_size pixels x input_channel channels.
 * packed_input:  output, UP_ROUND(plane_size, C8NUM) x UP_ROUND(ic, C4NUM).
 * input_sum:     output, UP_ROUND(plane_size, C8NUM) int32 entries.
 * conv_param:    only filter_quant_args_[0].zp_ is read (single filter zp
 *                because this is the per-layer variant).
 */
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                       size_t plane_size, ConvParameter *conv_param) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;   /* pixels handled by the full C8 blocks */
  size_t ic_4div = input_channel / C4NUM * C4NUM; /* channels handled by the full C4 groups */
  int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  /* per layer */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_ic = src_r;  /* NOTE(review): see body below; pointers walk 8 pixels at once */
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    /* ARM64 kernel: loads 4 channels from each of 8 pixels per iteration,
     * stores them packed, and accumulates per-pixel sums with saddlp
     * (pairwise widening adds: bytes -> halfwords -> words).  Labels 3/4/5
     * handle an input_channel remainder of 1/2/3; label 6 multiplies the
     * 8 accumulated sums by filter_zp and stores them to input_sum_r. */
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x14, %[input_sum_r] \n"
      "dup v20.4s, %w[filter_zp] \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"

      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"

      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[4],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[8],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[4],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[8],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n" /* x13 reads the 3rd leftover byte, x12 the first two */

      "ld1 {v0.h}[0],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[2],  [x13], %[src_stride]\n"
      "ld1 {v0.h}[2],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[6],  [x13], %[src_stride]\n"
      "ld1 {v0.h}[4],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6],  [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[2],  [x13], %[src_stride]\n"
      "ld1 {v1.h}[2],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[6],  [x13], %[src_stride]\n"
      "ld1 {v1.h}[4],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6],  [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "6: \n"
      "mul v16.4s, v16.4s, v20.4s \n"
      "mul v17.4s, v17.4s, v20.4s \n"

      "st1 {v16.4s}, [x14], #16 \n"
      "st1 {v17.4s}, [x14], #16 \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
        [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
      : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
        "v20");
#else
    /* Portable C reference of the kernel above. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* channel remainder (input_channel % 4) */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }
    /* zero-pad channels up to the C4 boundary */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    for (int i = 0; i < C8NUM; i++) {
      input_sum_r[i] = tmp_sum_value[i] * filter_zp;
    }
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
  }

  /* leftover pixels (plane_size % 8): scalar path, one pixel at a time */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4); /* pre-zero the whole residual C8 tile */
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      input_sum[hwi] = tmp_sum_value * filter_zp;
      src_r += input_channel;
      pack_r += C4NUM;
    }
    /* sums for the zero-padded tail pixels are zero */
    for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
      input_sum[hwi] = 0;
    }
  }
  return;
}
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) { | |||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||
| matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, | |||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, | |||
| filter_zp); | |||
| return; | |||
| } | |||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) { | |||
| int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; | |||
| MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, | |||
| conv_param->output_channel_, is_per_oc, filter_zp); | |||
| return; | |||
| } | |||
| // int8 convolution 3x3 | |||
| void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, | |||
| int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, | |||
| int task_id, ConvParameter *conv_param) { | |||
| int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); | |||
| int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); | |||
| int output_count = out_w_block * out_h_block; | |||
| int output_tile_count = UP_DIV(output_count, TILE_NUM); | |||
| int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); | |||
| int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; | |||
| const int block_unit_buffer_offset = 16 * C8NUM; | |||
| int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; | |||
| for (int batch = 0; batch < conv_param->input_batch_; batch++) { | |||
| int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; | |||
| int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; | |||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||
| int start_index = thread_id * TILE_NUM; | |||
| int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; | |||
| Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | |||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | |||
| out_w_block, conv_param); | |||
| Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, | |||
| transed_weight, conv_param->output_channel_, ic8, real_cal_num); | |||
| Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, | |||
| bias_data, start_index, real_cal_num, out_w_block, conv_param); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ | |||
| #include <string.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| @@ -24,9 +25,10 @@ | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "nnacl/int8/matmul_int8.h" | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -36,23 +38,6 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in | |||
| const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, | |||
| ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize); | |||
| // int8 convolution 1x1 | |||
| void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride); | |||
| void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t plane_size, ConvParameter *conv_param); | |||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp); | |||
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp); | |||
| // int8 convolution 3x3 | |||
| void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, | |||
| int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, | |||
| int task_id, ConvParameter *conv_param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -17,7 +17,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ | |||
| #include "nnacl/depth_to_space_parameter.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -15,9 +15,6 @@ | |||
| */ | |||
| #include "nnacl/int8/div_int8.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) { | |||
| int index = 0; | |||
| @@ -18,7 +18,9 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_DIV_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| // returns the high-32 bits of a * b with rounding | |||
| // assume that a and b is divided by 2^31, who fall into [-1, 1] | |||
| @@ -107,7 +107,7 @@ int CountLeadingZeroBits(uint32_t x) { | |||
| if (x == 0) { | |||
| return 8 * sizeof(uint32_t); | |||
| } | |||
| const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1); | |||
| const int32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1); | |||
| int leading_zeros = 0; | |||
| while (x < leading_positive) { | |||
| x <<= 1; | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -16,7 +16,7 @@ | |||
| */ | |||
| #include "nnacl/int8/gather_int8.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/errorcode.h" | |||
| int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, const int *indices, | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -19,7 +19,7 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| typedef struct HswishQuantArg { | |||
| double input_scale; | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "nnacl/int8/l2_norm_int8.h" | |||
| #include <limits.h> | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/errorcode.h" | |||
| int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, | |||
| @@ -18,7 +18,7 @@ | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/layer_norm_parameter.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PRELU_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "nnacl/int8/matmul_int8.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||
| int col16 = UP_ROUND(col, C16NUM); | |||
| @@ -17,7 +17,6 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ | |||
| #include <stdio.h> | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| @@ -15,15 +15,8 @@ | |||
| */ | |||
| #include "nnacl/int8/mul_int8.h" | |||
| #include "nnacl/mul_parameter.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #endif | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #ifdef ENABLE_NEON | |||
| int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, | |||
| int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { | |||
| int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); | |||
| @@ -19,6 +19,11 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/mul_parameter.h" | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_ | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/int8/matmul_int8.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); | |||
| void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); | |||
| void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); | |||
| void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); | |||
| void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, | |||
| int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, | |||
| bool per_channel, bool is_optimize); | |||
| #ifdef ENABLE_ARM | |||
| void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); | |||
| void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, | |||
| size_t oc_res, size_t stride); | |||
| #endif | |||
| void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); | |||
| void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | |||
| ConvQuantArg *quant_qrg); | |||
| void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | |||
| ConvQuantArg *quant_qrg); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
#endif  // MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_
| @@ -19,7 +19,7 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/power_parameter.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -14,8 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include <stdio.h> | |||
| #include "nnacl/int8/quantize.h" | |||
| const uint64_t dSignMask = 1ull << 63; | |||
| const uint64_t dExponentMask = 0x7ffull << 52; | |||
| @@ -57,8 +56,6 @@ void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t | |||
| /* multipiler is in[0x40000000, 0x7FFFFF80] range */ | |||
| *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); | |||
| if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) { | |||
| printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n", | |||
| quantized_multiplier[0]); | |||
| return; | |||
| } | |||
| /* shift is in [0, 31] range */ | |||
| @@ -28,11 +28,6 @@ | |||
| #define FILTER_PER_CHANNEL 0b010 | |||
| #define OUTPUT_PER_CHANNEL 0b100 | |||
| typedef struct QuantArg { | |||
| float scale_; | |||
| int32_t zp_; | |||
| } QuantArg; | |||
| typedef struct ConvQuantArg { | |||
| RoundingMode round_mode_; | |||
| CalFixedMultiplierMode quant_multiplier_mode_; | |||
| @@ -58,24 +53,6 @@ typedef struct ConcatQuantArg { | |||
| int8_t output_activation_max_; | |||
| } ConcatQuantArg; | |||
| typedef struct SqueezeQuantArg { | |||
| QuantArg *in_quant_args_; | |||
| QuantArg *out_quant_args_; | |||
| } SqueezeQuantArg; | |||
| typedef struct UnSqueezeQuantArg { | |||
| int *input_sizes_; | |||
| int output_size_; | |||
| int **input_shapes_; | |||
| int *output_shape_; | |||
| float alpha; | |||
| int axis_; | |||
| size_t input_num_; | |||
| size_t output_dim_; | |||
| QuantArg in_quant_args_; | |||
| QuantArg out_quant_args_; | |||
| } UnSqueezeQuantArg; | |||
| typedef struct PreluQuantArg { | |||
| int *input_sizes_; | |||
| int output_size_; | |||
| @@ -103,22 +80,6 @@ typedef struct MatmulQuantArg { | |||
| int32_t quant_multiplier; | |||
| } MatmulQuantArg; | |||
| typedef struct PadQuantArg { | |||
| QuantArg *in_quant_args_; | |||
| QuantArg *out_quanr_args_; | |||
| int8_t *constant_value_; | |||
| } PadQuantArg; | |||
| typedef struct MulQuantArg { | |||
| QuantArg in_quant_args_[2]; | |||
| QuantArg out_quant_arg_; | |||
| int output_multiplier_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| int shift_left_; | |||
| int shift_right_; | |||
| } MulQuantArg; | |||
| typedef struct CropQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_; | |||
| @@ -142,13 +103,6 @@ typedef struct GatherQuantArg { | |||
| int zp_out_; | |||
| } GatherQuantArg; | |||
| typedef struct SplitQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_[20]; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } SplitQuantArg; | |||
| typedef struct SoftmaxQuantArg { | |||
| QuantArg in_quant_args_; | |||
| QuantArg out_quant_arg_; | |||
| @@ -159,19 +113,6 @@ typedef struct SoftmaxQuantArg { | |||
| int shift_right_; | |||
| } SoftmaxQuantArg; | |||
| typedef struct ReshapeQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } ReshapeQuantArg; | |||
| typedef struct QuantMulArg { | |||
| int32_t multiplier_; | |||
| int left_shift_; | |||
| int right_shift_; | |||
| } QuantMulArg; | |||
| typedef struct SubQuantArg { | |||
| QuantArg in0_args_; | |||
| QuantArg in1_args_; | |||
| @@ -227,21 +168,6 @@ typedef struct ReduceQuantArg { | |||
| int sum_square_right_shift_; | |||
| } ReduceQuantArg; | |||
| typedef struct SliceQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } SliceQuantArg; | |||
| typedef struct PowerQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg exp_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } PowerQuantArg; | |||
| typedef struct LeakyReluQuantArg { | |||
| OpParameter op_parameter_; | |||
| PreluQuantArg quant_arg; | |||
| @@ -17,7 +17,7 @@ | |||
| #include <stdint.h> | |||
| #include "nnacl/int8/reduce_int8.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/common_func.h" | |||
| int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { | |||
| @@ -16,7 +16,9 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_ | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -19,8 +19,8 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| typedef struct ReluXQuantArg { | |||
| QuantArg input_arg; | |||
| @@ -16,6 +16,8 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/reshape_parameter.h" | |||
| @@ -16,7 +16,7 @@ | |||
| #include <math.h> | |||
| #include "nnacl/int8/resize_int8.h" | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/errorcode.h" | |||
| int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, | |||
| @@ -21,7 +21,7 @@ | |||
| #endif | |||
| #include <memory.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/resize_parameter.h" | |||
| #ifdef __cplusplus | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "nnacl/int8/scale_int8.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef ENABLE_NEON | |||
| int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, | |||
| @@ -19,6 +19,8 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/scale.h" | |||
| #include "nnacl/nnacl_common.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -19,7 +19,7 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -15,8 +15,6 @@ | |||
| */ | |||
| #include "nnacl/int8/slice_int8.h" | |||
| #include <string.h> | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param) { | |||
| double input_scale = param->quant_arg_.in_args_.scale_; | |||
| @@ -16,8 +16,11 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_ | |||
| #include <math.h> | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/slice_parameter.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -15,9 +15,6 @@ | |||
| */ | |||
| #include "nnacl/int8/softmax_int8.h" | |||
| #include <math.h> | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data, | |||
| SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) { | |||
| @@ -17,9 +17,11 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/softmax_parameter.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -16,6 +16,8 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/split_parameter.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_ | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/squeeze_parameter.h" | |||
| #ifdef __cplusplus | |||
| @@ -19,7 +19,7 @@ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/int8/common_func_int8.h" | |||
| #endif | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef ENABLE_NEON | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_SUB_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -18,8 +18,8 @@ | |||
| #define MINDSPORE_LITE_NNACL_INT8_TANH_INT8_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "nnacl/quantization/fixed_point.h" | |||
| #include "nnacl/int8/quantize.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #include "nnacl/int8/quant_dtype_cast_int8.h" | |||
| #include "nnacl/fp32/activation_fp32.h" | |||
| @@ -16,7 +16,6 @@ | |||
| #include "nnacl/unsqueeze_parameter.h" | |||
| #include "nnacl/int8/unsqueeze_int8.h" | |||
| #include <string.h> | |||
| int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) { | |||
| float output_scale = para_->quant_arg.out_quant_args_.scale_; | |||
| @@ -17,7 +17,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_L2NORM_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| typedef struct L2NormParameter { | |||
| // Primitive parameter | |||
| @@ -17,7 +17,7 @@ | |||
| #define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #include "mindspore/lite/nnacl/int8/quantize.h" | |||
| enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 }; | |||
| typedef struct LayerNormParameter { | |||
| @@ -18,7 +18,6 @@ | |||
| #define MINDSPORE_LITE_NNACL_MATMUL_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| const int *input_sum, const int *bias); | |||
| @@ -18,7 +18,16 @@ | |||
| #define MINDSPORE_LITE_NNACL_MUL_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| typedef struct MulQuantArg { | |||
| QuantArg in_quant_args_[2]; | |||
| QuantArg out_quant_arg_; | |||
| int output_multiplier_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| int shift_left_; | |||
| int shift_right_; | |||
| } MulQuantArg; | |||
| typedef struct MulParameter { | |||
| // Primitive parameter | |||
| @@ -14,11 +14,4 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PYBIND_API_PYBIND_PATCH_H_ | |||
| #define PYBIND_API_PYBIND_PATCH_H_ | |||
| namespace pybind11 { | |||
| PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) | |||
| } | |||
| #endif // PYBIND_API_PYBIND_PATCH_H_ | |||
| #include "nnacl/nnacl_common.h" | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ | |||
| #define MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| inline void ComputeStrides(const int *shape, int *strides, const int ndim) { | |||
| int stride = 1; | |||
| for (int i = ndim - 1; i >= 0; i--) { | |||
| strides[i] = stride; | |||
| stride *= shape[i]; | |||
| } | |||
| } | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ | |||
| @@ -28,6 +28,7 @@ | |||
| #include <stdint.h> | |||
| #include <stdlib.h> | |||
| #include <stdbool.h> | |||
| #include <string.h> | |||
| #define C2NUM 2 | |||
| #define C4NUM 4 | |||
| @@ -78,6 +79,17 @@ typedef struct OpParameter { | |||
| int thread_num_; | |||
| } OpParameter; | |||
| typedef struct QuantArg { | |||
| float scale_; | |||
| int32_t zp_; | |||
| } QuantArg; | |||
| typedef struct QuantMulArg { | |||
| int32_t multiplier_; | |||
| int left_shift_; | |||
| int right_shift_; | |||
| } QuantMulArg; | |||
| typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; | |||
| typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode; | |||
| typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; | |||
| @@ -17,102 +17,12 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_PACK_H_ | |||
| #define MINDSPORE_LITE_NNACL_PACK_H_ | |||
| #include <stdio.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/fp32/pack_fp32.h" | |||
| #include "nnacl/int8/pack_int8.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, | |||
| int block_index); | |||
| void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); | |||
| void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, | |||
| int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, | |||
| bool per_channel, bool is_optimize); | |||
| void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); | |||
| void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, | |||
| size_t plane_size, size_t input_channel, size_t output_channel); | |||
| void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, | |||
| size_t plane_size, size_t input_channel, size_t output_channel); | |||
| void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); | |||
| void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); | |||
| void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); | |||
| void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); | |||
| void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); | |||
| void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); | |||
| void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); | |||
| void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); | |||
| void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); | |||
| void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); | |||
| void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); | |||
| void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | |||
| ConvQuantArg *quant_qrg); | |||
| void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, | |||
| ConvQuantArg *quant_qrg); | |||
| #ifdef ENABLE_ARM | |||
| void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); | |||
| void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, | |||
| size_t oc_res, size_t stride); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -17,11 +17,16 @@ | |||
| #define MINDSPORE_LITE_NNACL_PAD_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #define MAX_PAD_SIZE 8 | |||
| #define DEFAULT_PAD_NDIMS 4 | |||
| typedef struct PadQuantArg { | |||
| QuantArg *in_quant_args_; | |||
| QuantArg *out_quanr_args_; | |||
| int8_t *constant_value_; | |||
| } PadQuantArg; | |||
| typedef struct PadParameter { | |||
| // Primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -17,7 +17,6 @@ | |||
| #define MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; | |||
| @@ -18,7 +18,14 @@ | |||
| #define MINDSPORE_LITE_NNACL_POWER_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| typedef struct PowerQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg exp_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } PowerQuantArg; | |||
| typedef struct PowerParameter { | |||
| // Primitive parameter | |||
| @@ -18,7 +18,13 @@ | |||
| #define MINDSPORE_LITE_NNACL_RESHAHPE_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| typedef struct ReshapeQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } ReshapeQuantArg; | |||
| typedef struct ReshapeParameter { | |||
| // primitive parameter | |||
| @@ -17,8 +17,8 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_SCALE_H_ | |||
| #define MINDSPORE_LITE_NNACL_SCALE_H_ | |||
| #include <mindspore/lite/nnacl/quantization/quantize.h> | |||
| #include "nnacl/op_base.h" | |||
| typedef struct ScaleParameter { | |||
| // primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -16,7 +16,6 @@ | |||
| #include "nnacl/scatter_nd.h" | |||
| #include <string.h> | |||
| #include <stdio.h> | |||
| #include "nnacl/errorcode.h" | |||
| int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) { | |||
| @@ -18,10 +18,16 @@ | |||
| #define MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #define SLICE_SHAPE_MAX_SIZE 4 | |||
| typedef struct SliceQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } SliceQuantArg; | |||
| typedef struct SliceParameter { | |||
| // primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -18,8 +18,16 @@ | |||
| #define MINDSPORE_LITE_NNACL_SPLIT_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #define SPLIT_STRIDES_SIZE 32 | |||
| typedef struct SplitQuantArg { | |||
| QuantArg in_args_; | |||
| QuantArg out_args_[20]; | |||
| int output_activation_min_; | |||
| int output_activation_max_; | |||
| } SplitQuantArg; | |||
| typedef struct SplitParameter { | |||
| // primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -16,11 +16,16 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_ | |||
| #define MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #define SQUEEZE_OFFSET_MAX_SIZE 4 | |||
| typedef struct SqueezeQuantArg { | |||
| QuantArg *in_quant_args_; | |||
| QuantArg *out_quant_args_; | |||
| } SqueezeQuantArg; | |||
| typedef struct SqueezeParameter { | |||
| // primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -16,11 +16,26 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_ | |||
| #define MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_ | |||
| #include <string.h> | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/quantization/quantize.h" | |||
| #define UNSQUEEZE_OFFSET_MAX_SIZE 4 | |||
| typedef struct UnSqueezeQuantArg { | |||
| int *input_sizes_; | |||
| int output_size_; | |||
| int **input_shapes_; | |||
| int *output_shape_; | |||
| float alpha; | |||
| int axis_; | |||
| size_t input_num_; | |||
| size_t output_dim_; | |||
| QuantArg in_quant_args_; | |||
| QuantArg out_quant_args_; | |||
| } UnSqueezeQuantArg; | |||
| typedef struct UnSqueezeParameter { | |||
| // primitive parameter | |||
| OpParameter op_parameter_; | |||
| @@ -140,810 +140,3 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float | |||
| out_tile_index++; | |||
| } | |||
| } | |||
| // int8 conv3x3 | |||
| void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { | |||
| #ifdef ENABLE_ARM | |||
| int16x8_t zp = vdupq_n_s16(input_zp); | |||
| int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); | |||
| int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); | |||
| int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); | |||
| int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); | |||
| int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); | |||
| int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); | |||
| int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); | |||
| int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); | |||
| int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); | |||
| int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); | |||
| int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); | |||
| int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); | |||
| int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); | |||
| int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); | |||
| int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); | |||
| int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); | |||
| int16x8_t t00 = vsubq_s16(d00, d20); | |||
| int16x8_t t01 = vsubq_s16(d01, d21); | |||
| int16x8_t t02 = vsubq_s16(d02, d22); | |||
| int16x8_t t03 = vsubq_s16(d03, d23); | |||
| int16x8_t t10 = vaddq_s16(d10, d20); | |||
| int16x8_t t11 = vaddq_s16(d11, d21); | |||
| int16x8_t t12 = vaddq_s16(d12, d22); | |||
| int16x8_t t13 = vaddq_s16(d13, d23); | |||
| int16x8_t t20 = vsubq_s16(d20, d10); | |||
| int16x8_t t21 = vsubq_s16(d21, d11); | |||
| int16x8_t t22 = vsubq_s16(d22, d12); | |||
| int16x8_t t23 = vsubq_s16(d23, d13); | |||
| int16x8_t t30 = vsubq_s16(d10, d30); | |||
| int16x8_t t31 = vsubq_s16(d11, d31); | |||
| int16x8_t t32 = vsubq_s16(d12, d32); | |||
| int16x8_t t33 = vsubq_s16(d13, d33); | |||
| int16x8_t m00 = vsubq_s16(t00, t02); | |||
| int16x8_t m01 = vaddq_s16(t01, t02); | |||
| int16x8_t m02 = vsubq_s16(t02, t01); | |||
| int16x8_t m03 = vsubq_s16(t01, t03); | |||
| int16x8_t m10 = vsubq_s16(t10, t12); | |||
| int16x8_t m11 = vaddq_s16(t11, t12); | |||
| int16x8_t m12 = vsubq_s16(t12, t11); | |||
| int16x8_t m13 = vsubq_s16(t11, t13); | |||
| int16x8_t m20 = vsubq_s16(t20, t22); | |||
| int16x8_t m21 = vaddq_s16(t21, t22); | |||
| int16x8_t m22 = vsubq_s16(t22, t21); | |||
| int16x8_t m23 = vsubq_s16(t21, t23); | |||
| int16x8_t m30 = vsubq_s16(t30, t32); | |||
| int16x8_t m31 = vaddq_s16(t31, t32); | |||
| int16x8_t m32 = vsubq_s16(t32, t31); | |||
| int16x8_t m33 = vsubq_s16(t31, t33); | |||
| vst1q_s16(trans_input_data, m00); | |||
| vst1q_s16(trans_input_data + step, m01); | |||
| vst1q_s16(trans_input_data + 2 * step, m02); | |||
| vst1q_s16(trans_input_data + 3 * step, m03); | |||
| vst1q_s16(trans_input_data + 4 * step, m10); | |||
| vst1q_s16(trans_input_data + 5 * step, m11); | |||
| vst1q_s16(trans_input_data + 6 * step, m12); | |||
| vst1q_s16(trans_input_data + 7 * step, m13); | |||
| vst1q_s16(trans_input_data + 8 * step, m20); | |||
| vst1q_s16(trans_input_data + 9 * step, m21); | |||
| vst1q_s16(trans_input_data + 10 * step, m22); | |||
| vst1q_s16(trans_input_data + 11 * step, m23); | |||
| vst1q_s16(trans_input_data + 12 * step, m30); | |||
| vst1q_s16(trans_input_data + 13 * step, m31); | |||
| vst1q_s16(trans_input_data + 14 * step, m32); | |||
| vst1q_s16(trans_input_data + 15 * step, m33); | |||
| #else | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| int16_t *local_ptr = tmp_data + i; | |||
| int16_t d00 = local_ptr[0] - input_zp; | |||
| int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; | |||
| int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; | |||
| int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; | |||
| int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp; | |||
| int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; | |||
| int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; | |||
| int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; | |||
| int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; | |||
| int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; | |||
| int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; | |||
| int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; | |||
| int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; | |||
| int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; | |||
| int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; | |||
| int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; | |||
| int16_t t00 = d00 - d20; | |||
| int16_t t01 = d01 - d21; | |||
| int16_t t02 = d02 - d22; | |||
| int16_t t03 = d03 - d23; | |||
| int16_t t10 = d10 + d20; | |||
| int16_t t11 = d11 + d21; | |||
| int16_t t12 = d12 + d22; | |||
| int16_t t13 = d13 + d23; | |||
| int16_t t20 = d20 - d10; | |||
| int16_t t21 = d21 - d11; | |||
| int16_t t22 = d22 - d12; | |||
| int16_t t23 = d23 - d13; | |||
| int16_t t30 = d10 - d30; | |||
| int16_t t31 = d11 - d31; | |||
| int16_t t32 = d12 - d32; | |||
| int16_t t33 = d13 - d33; | |||
| int16_t m00 = t00 - t02; | |||
| int16_t m01 = t01 + t02; | |||
| int16_t m02 = t02 - t01; | |||
| int16_t m03 = t01 - t03; | |||
| int16_t m10 = t10 - t12; | |||
| int16_t m11 = t11 + t12; | |||
| int16_t m12 = t12 - t11; | |||
| int16_t m13 = t11 - t13; | |||
| int16_t m20 = t20 - t22; | |||
| int16_t m21 = t21 + t22; | |||
| int16_t m22 = t22 - t21; | |||
| int16_t m23 = t21 - t23; | |||
| int16_t m30 = t30 - t32; | |||
| int16_t m31 = t31 + t32; | |||
| int16_t m32 = t32 - t31; | |||
| int16_t m33 = t31 - t33; | |||
| (trans_input_data + i)[0] = m00; | |||
| (trans_input_data + i + step)[0] = m01; | |||
| (trans_input_data + i + 2 * step)[0] = m02; | |||
| (trans_input_data + i + 3 * step)[0] = m03; | |||
| (trans_input_data + i + 4 * step)[0] = m10; | |||
| (trans_input_data + i + 5 * step)[0] = m11; | |||
| (trans_input_data + i + 6 * step)[0] = m12; | |||
| (trans_input_data + i + 7 * step)[0] = m13; | |||
| (trans_input_data + i + 8 * step)[0] = m20; | |||
| (trans_input_data + i + 9 * step)[0] = m21; | |||
| (trans_input_data + i + 10 * step)[0] = m22; | |||
| (trans_input_data + i + 11 * step)[0] = m23; | |||
| (trans_input_data + i + 12 * step)[0] = m30; | |||
| (trans_input_data + i + 13 * step)[0] = m31; | |||
| (trans_input_data + i + 14 * step)[0] = m32; | |||
| (trans_input_data + i + 15 * step)[0] = m33; | |||
| } | |||
| #endif | |||
| } | |||
| void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, | |||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| // input data format : nhwc | |||
| int input_channel = conv_param->input_channel_; | |||
| int input_width = conv_param->input_w_; | |||
| int input_height = conv_param->input_h_; | |||
| int pad_w = conv_param->pad_l_; | |||
| int pad_h = conv_param->pad_u_; | |||
| ConvQuantArg quant_arg = conv_param->conv_quant_arg_; | |||
| int input_zp = quant_arg.input_quant_args_[0].zp_; | |||
| const int ic8 = UP_DIV(input_channel, C8NUM); | |||
| const int input_unit = 4; | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { | |||
| int x_id = start_index + cal_id; | |||
| int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; | |||
| int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; | |||
| int real_x_start = origin_x > 0 ? 0 : -origin_x; | |||
| int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); | |||
| int real_y_start = origin_y > 0 ? 0 : -origin_y; | |||
| int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); | |||
| int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); | |||
| int dst_plane_offset = cal_id * C8NUM; | |||
| for (int ic = 0; ic < ic8; ic++) { | |||
| // copy data from origin input to tmp buffer | |||
| for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; | |||
| int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; | |||
| for (int j = real_y_start; j < real_y_end; j++) { | |||
| const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); | |||
| int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); | |||
| memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); | |||
| } | |||
| // input transform | |||
| int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; | |||
| size_t dst_step = ic8 * C8NUM * TILE_NUM; | |||
| int16_t *trans_input_ptr = trans_input + dst_ic8_offset; | |||
| Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); | |||
| } | |||
| } | |||
| } | |||
| void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, | |||
| int kernel_plane) { | |||
| const int input_unit = 4; | |||
| int dst_step = iC8 * C8NUM * C4NUM; | |||
| for (int o = 0; o < output_channel; o++) { | |||
| int oc4_block_num = o / C4NUM; | |||
| int oc4_block_rem = o % C4NUM; | |||
| int src_oc_offset = o * iC8 * C8NUM * kernel_plane; | |||
| int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; | |||
| for (int i = 0; i < iC8; i++) { | |||
| const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; | |||
| int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; | |||
| #ifdef ENABLE_ARM | |||
| int16x8_t g00 = vld1q_s16(src_ic8_ptr); | |||
| int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); | |||
| int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); | |||
| int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); | |||
| int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); | |||
| int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); | |||
| int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); | |||
| int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); | |||
| int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); | |||
| int16x8_t dst00 = vmulq_n_s16(g00, 2); | |||
| int16x8_t dst01 = vmulq_n_s16(g01, 2); | |||
| int16x8_t dst02 = vmulq_n_s16(g02, 2); | |||
| int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20); | |||
| int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); | |||
| int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); | |||
| int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); | |||
| int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); | |||
| int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); | |||
| int16x8_t dst30 = vmulq_n_s16(g20, 2); | |||
| int16x8_t dst31 = vmulq_n_s16(g21, 2); | |||
| int16x8_t dst32 = vmulq_n_s16(g22, 2); | |||
| int16x8_t m00 = vmulq_n_s16(dst00, 2); | |||
| int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); | |||
| int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); | |||
| int16x8_t m03 = vmulq_n_s16(dst02, 2); | |||
| int16x8_t m10 = vmulq_n_s16(dst10, 2); | |||
| int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); | |||
| int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); | |||
| int16x8_t m13 = vmulq_n_s16(dst12, 2); | |||
| int16x8_t m20 = vmulq_n_s16(dst20, 2); | |||
| int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); | |||
| int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); | |||
| int16x8_t m23 = vmulq_n_s16(dst22, 2); | |||
| int16x8_t m30 = vmulq_n_s16(dst30, 2); | |||
| int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); | |||
| int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); | |||
| int16x8_t m33 = vmulq_n_s16(dst32, 2); | |||
| dst_ic8_ptr[0] = m00[0]; | |||
| dst_ic8_ptr[4] = m00[1]; | |||
| dst_ic8_ptr[8] = m00[2]; | |||
| dst_ic8_ptr[12] = m00[3]; | |||
| dst_ic8_ptr[16] = m00[4]; | |||
| dst_ic8_ptr[20] = m00[5]; | |||
| dst_ic8_ptr[24] = m00[6]; | |||
| dst_ic8_ptr[28] = m00[7]; | |||
| dst_ic8_ptr[0 + dst_step] = m01[0]; | |||
| dst_ic8_ptr[4 + dst_step] = m01[1]; | |||
| dst_ic8_ptr[8 + dst_step] = m01[2]; | |||
| dst_ic8_ptr[12 + dst_step] = m01[3]; | |||
| dst_ic8_ptr[16 + dst_step] = m01[4]; | |||
| dst_ic8_ptr[20 + dst_step] = m01[5]; | |||
| dst_ic8_ptr[24 + dst_step] = m01[6]; | |||
| dst_ic8_ptr[28 + dst_step] = m01[7]; | |||
| dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; | |||
| dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; | |||
| dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; | |||
| dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; | |||
| dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; | |||
| dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; | |||
| dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; | |||
| dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; | |||
| dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; | |||
| dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; | |||
| dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; | |||
| dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; | |||
| dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; | |||
| dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; | |||
| dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; | |||
| dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; | |||
| dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; | |||
| dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; | |||
| dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; | |||
| dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; | |||
| dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; | |||
| dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; | |||
| dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; | |||
| dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; | |||
| dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; | |||
| dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; | |||
| dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; | |||
| dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; | |||
| dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; | |||
| dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; | |||
| dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; | |||
| dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; | |||
| dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; | |||
| dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; | |||
| dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; | |||
| dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; | |||
| dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; | |||
| dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; | |||
| dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; | |||
| dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; | |||
| dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; | |||
| dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; | |||
| dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; | |||
| dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; | |||
| dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; | |||
| dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; | |||
| dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; | |||
| dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; | |||
| dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; | |||
| dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; | |||
| dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; | |||
| dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; | |||
| dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; | |||
| dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; | |||
| dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; | |||
| dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; | |||
| dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; | |||
| dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; | |||
| dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; | |||
| dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; | |||
| dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; | |||
| dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; | |||
| dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; | |||
| dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; | |||
| dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; | |||
| dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; | |||
| dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; | |||
| dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; | |||
| dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; | |||
| dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; | |||
| dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; | |||
| dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; | |||
| dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; | |||
| dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; | |||
| dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; | |||
| dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; | |||
| dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; | |||
| dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; | |||
| dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; | |||
| dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; | |||
| dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; | |||
| dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; | |||
| dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; | |||
| dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; | |||
| dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; | |||
| dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; | |||
| dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; | |||
| dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; | |||
| dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; | |||
| dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; | |||
| dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; | |||
| dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; | |||
| dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; | |||
| dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; | |||
| dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; | |||
| dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; | |||
| dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; | |||
| dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; | |||
| dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; | |||
| dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; | |||
| dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; | |||
| dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; | |||
| dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; | |||
| dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; | |||
| dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; | |||
| dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; | |||
| dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; | |||
| dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; | |||
| dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; | |||
| dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; | |||
| dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; | |||
| dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; | |||
| #else | |||
| for (int j = 0; j < C8NUM; j++) { | |||
| const int16_t *local_ptr = src_ic8_ptr + j; | |||
| int16_t dst00 = local_ptr[0] * 2; | |||
| int16_t dst01 = (local_ptr + 8)[0] * 2; | |||
| int16_t dst02 = (local_ptr + 16)[0] * 2; | |||
| int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; | |||
| int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; | |||
| int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; | |||
| int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; | |||
| int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; | |||
| int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; | |||
| int16_t dst30 = (local_ptr + 48)[0] * 2; | |||
| int16_t dst31 = (local_ptr + 56)[0] * 2; | |||
| int16_t dst32 = (local_ptr + 64)[0] * 2; | |||
| int16_t m00 = dst00 * 2; | |||
| int16_t m01 = dst00 + dst01 + dst02; | |||
| int16_t m02 = dst00 - dst01 + dst02; | |||
| int16_t m03 = dst02 * 2; | |||
| int16_t m10 = dst10 * 2; | |||
| int16_t m11 = dst10 + dst11 + dst12; | |||
| int16_t m12 = dst10 - dst11 + dst12; | |||
| int16_t m13 = dst12 * 2; | |||
| int16_t m20 = dst20 * 2; | |||
| int16_t m21 = dst20 + dst21 + dst22; | |||
| int16_t m22 = dst20 - dst21 + dst22; | |||
| int16_t m23 = dst22 * 2; | |||
| int16_t m30 = dst30 * 2; | |||
| int16_t m31 = dst30 + dst31 + dst32; | |||
| int16_t m32 = dst30 - dst31 + dst32; | |||
| int16_t m33 = dst32 * 2; | |||
| *(dst_ic8_ptr + j * 4) = m00; | |||
| *(dst_ic8_ptr + j * 4 + dst_step) = m01; | |||
| *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; | |||
| *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; | |||
| *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; | |||
| *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; | |||
| *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; | |||
| *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; | |||
| *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; | |||
| *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; | |||
| *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; | |||
| *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; | |||
| *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; | |||
| *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; | |||
| *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; | |||
| *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| } | |||
// Requantizes one 4x4 tile of int32 Winograd GEMM output down to a 2x2 block
// of int8 pixels for one group of C4NUM output channels.
//
// gemm_out:    16 points x C4NUM channels of int32 accumulator data
//              (channel-innermost, stride 4 between points).
// bias_data:   per-channel int32 bias for this channel group.
// output_data: destination of the 2x2 output block (C4NUM-interleaved).
// h_not_bound / w_not_bound: whether the second output row / column lies
//              inside the real output; out-of-bounds positions are skipped.
// output_w:    output width in pixels, used for the row stride.
// real_num:    NOTE(review): not read anywhere in this function — all paths
//              write the full C4NUM channels (padded layout?); confirm intent.
// oc_start:    absolute channel offset, used to index per-channel quant params.
// conv_param:  supplies multiplier/shift arrays, output zero point and the
//              activation clamp range.
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
                           bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
  int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
  int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
  int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
  int out_max = conv_param->conv_quant_arg_.out_act_max_[0];
#ifdef ENABLE_ARM
  // NEON path: each int32x4_t holds the C4NUM channels of one tile point.
  int32x4_t bias_ptr = vld1q_s32(bias_data);
  int32x4_t s00 = vld1q_s32(gemm_out);
  int32x4_t s01 = vld1q_s32(gemm_out + 4);
  int32x4_t s02 = vld1q_s32(gemm_out + 8);
  int32x4_t s03 = vld1q_s32(gemm_out + 12);
  int32x4_t s10 = vld1q_s32(gemm_out + 16);
  int32x4_t s11 = vld1q_s32(gemm_out + 20);
  int32x4_t s12 = vld1q_s32(gemm_out + 24);
  int32x4_t s13 = vld1q_s32(gemm_out + 28);
  int32x4_t s20 = vld1q_s32(gemm_out + 32);
  int32x4_t s21 = vld1q_s32(gemm_out + 36);
  int32x4_t s22 = vld1q_s32(gemm_out + 40);
  int32x4_t s23 = vld1q_s32(gemm_out + 44);
  int32x4_t s30 = vld1q_s32(gemm_out + 48);
  int32x4_t s31 = vld1q_s32(gemm_out + 52);
  int32x4_t s32 = vld1q_s32(gemm_out + 56);
  int32x4_t s33 = vld1q_s32(gemm_out + 60);
  // First transform pass: fold the 4 tile rows down to 2, halving (>>1) to
  // compensate the doubled input-transform scale.
  int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1);
  int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1);
  int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1);
  int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1);
  int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1);
  int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1);
  int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1);
  int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1);
  // Second pass: fold the 4 columns down to 2 and add the bias, yielding the
  // four output pixels d00/d01/d10/d11.
  int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr);
  int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr);
  int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
  int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);
  int32x4_t out_multiplier;
  int32x4_t ls;
  int32x4_t rs;
  // Per-channel quantization loads one multiplier/shift per output channel;
  // per-tensor broadcasts element 0 to all lanes.
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    out_multiplier = vld1q_s32(quant_multiplier + oc_start);
    ls = vld1q_s32(left_shift + oc_start);
    rs = vld1q_s32(right_shift + oc_start);
  } else {
    out_multiplier = vdupq_n_s32(quant_multiplier[0]);
    ls = vdupq_n_s32(left_shift[0]);
    rs = vdupq_n_s32(right_shift[0]);
  }
  int32x4_t out_zp = vdupq_n_s32(output_zp);
  int32x4_t output_min = vdupq_n_s32(out_min);
  int32x4_t output_max = vdupq_n_s32(out_max);
  // Requantize d00: left shift, fixed-point multiply, then a rounding right
  // shift with the sign-dependent "carry" correction (round-to-nearest away
  // from zero for negative values), add zero point, clamp to activation range.
  d00 = vqshlq_s32(d00, ls);
  d00 = vqrdmulhq_s32(d00, out_multiplier);
  int32x4_t carry = vandq_s32(d00, rs);
  carry = vshrq_n_s32(carry, 31);
  d00 = vqaddq_s32(d00, carry);
  d00 = vqrshlq_s32(d00, rs);
  d00 = vaddq_s32(d00, out_zp);
  d00 = vmaxq_s32(d00, output_min);
  d00 = vminq_s32(d00, output_max);
  // Same requantization sequence for the other three output pixels.
  d01 = vqshlq_s32(d01, ls);
  d01 = vqrdmulhq_s32(d01, out_multiplier);
  carry = vandq_s32(d01, rs);
  carry = vshrq_n_s32(carry, 31);
  d01 = vqaddq_s32(d01, carry);
  d01 = vqrshlq_s32(d01, rs);
  d01 = vaddq_s32(d01, out_zp);
  d01 = vmaxq_s32(d01, output_min);
  d01 = vminq_s32(d01, output_max);
  d10 = vqshlq_s32(d10, ls);
  d10 = vqrdmulhq_s32(d10, out_multiplier);
  carry = vandq_s32(d10, rs);
  carry = vshrq_n_s32(carry, 31);
  d10 = vqaddq_s32(d10, carry);
  d10 = vqrshlq_s32(d10, rs);
  d10 = vaddq_s32(d10, out_zp);
  d10 = vmaxq_s32(d10, output_min);
  d10 = vminq_s32(d10, output_max);
  d11 = vqshlq_s32(d11, ls);
  d11 = vqrdmulhq_s32(d11, out_multiplier);
  carry = vandq_s32(d11, rs);
  carry = vshrq_n_s32(carry, 31);
  d11 = vqaddq_s32(d11, carry);
  d11 = vqrshlq_s32(d11, rs);
  d11 = vaddq_s32(d11, out_zp);
  d11 = vmaxq_s32(d11, output_min);
  d11 = vminq_s32(d11, output_max);
  // Store the 2x2 block; the right column and bottom row are written only
  // when they fall inside the real output bounds.
  (output_data)[0] = (int8_t)d00[0];
  (output_data + 1)[0] = (int8_t)d00[1];
  (output_data + 2)[0] = (int8_t)d00[2];
  (output_data + 3)[0] = (int8_t)d00[3];
  if (w_not_bound) {
    *(output_data + 4) = (int8_t)d01[0];
    *(output_data + 5) = (int8_t)d01[1];
    *(output_data + 6) = (int8_t)d01[2];
    *(output_data + 7) = (int8_t)d01[3];
  }
  if (h_not_bound) {
    *(output_data + output_w * 4) = (int8_t)d10[0];
    *(output_data + output_w * 4 + 1) = (int8_t)d10[1];
    *(output_data + output_w * 4 + 2) = (int8_t)d10[2];
    *(output_data + output_w * 4 + 3) = (int8_t)d10[3];
    if (w_not_bound) {
      *(output_data + output_w * 4 + 4) = (int8_t)d11[0];
      *(output_data + output_w * 4 + 5) = (int8_t)d11[1];
      *(output_data + output_w * 4 + 6) = (int8_t)d11[2];
      *(output_data + output_w * 4 + 7) = (int8_t)d11[3];
    }
  }
#else
  // Scalar fallback: same transform per channel. NOTE(review): `/ 2` here
  // rounds toward zero while the NEON path's `>> 1` rounds toward -inf, so
  // negative intermediates can differ by 1 between the two paths — confirm
  // whether this is acceptable.
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    // Per-channel quantization: multiplier/shift indexed by oc_start + i.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;
      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];
      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];
      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];
      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];
      // Row fold (4 -> 2) with scale compensation.
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;
      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;
      // Column fold (4 -> 2) plus bias: the four output pixels.
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
      int oc_index = oc_start + i;
      // Requantize: left-shift, saturating doubling high multiply, rounding
      // divide by 2^right_shift, zero point, activation clamp.
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;
      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;
      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;
      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;
      // Guarded 2x2 store, as in the NEON path.
      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  } else {
    // Per-tensor quantization: identical transform, but multiplier/shift
    // always come from index 0.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;
      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];
      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];
      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];
      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;
      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;
      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;
      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;
      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;
      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  }
#endif
}
| void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, | |||
| int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| int output_channel = conv_param->output_channel_; | |||
| int output_w = conv_param->output_w_; | |||
| int output_h = conv_param->output_h_; | |||
| const int oc4 = UP_DIV(output_channel, C4NUM); | |||
| const int input_unit = 4; | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int out_w_index = (start_index + i) % out_w_block; | |||
| int out_h_index = (start_index + i) / out_w_block; | |||
| int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; | |||
| int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); | |||
| for (int j = 0; j < oc4; j++) { | |||
| int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; | |||
| int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; | |||
| const int32_t *src_ptr = gemm_out + src_oc4_offset; | |||
| const int32_t *bias_ptr = bias_data + j * C4NUM; | |||
| int8_t *dst_ptr = out_data + dst_oc4_offset; | |||
| // output transform | |||
| int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; | |||
| bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; | |||
| bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; | |||
| Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, | |||
| conv_param); | |||
| } | |||
| } | |||
| } | |||