|
|
|
@@ -18,16 +18,19 @@ |
|
|
|
#ifdef ENABLE_NEON |
|
|
|
#include <arm_neon.h> |
|
|
|
#endif |
|
|
|
#ifdef ENABLE_AVX |
|
|
|
#include <x86intrin.h> |
|
|
|
#include "nnacl/x86_64_avx/common_utils.h" |
|
|
|
#endif |
|
|
|
#include "nnacl/int8/fixed_point.h" |
|
|
|
|
|
|
|
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { |
|
|
|
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); |
|
|
|
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); |
|
|
|
int index = 0; |
|
|
|
|
|
|
|
#ifdef ENABLE_ARM |
|
|
|
const int8x16_t min_vec = vdupq_n_s8(params->min_); |
|
|
|
const int8x16_t max_vac = vdupq_n_s8(params->max_); |
|
|
|
const int8x16_t max_vec = vdupq_n_s8(params->max_); |
|
|
|
|
|
|
|
const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_); |
|
|
|
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_); |
|
|
|
@@ -142,12 +145,11 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz |
|
|
|
const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); |
|
|
|
|
|
|
|
const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); |
|
|
|
const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out)); |
|
|
|
const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); |
|
|
|
|
|
|
|
vst1q_s8(output + index, int8_out); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
for (; index < size; index++) { |
|
|
|
const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; |
|
|
|
const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; |
|
|
|
@@ -173,7 +175,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i |
|
|
|
#ifdef ENABLE_ARM |
|
|
|
/* const value init */ |
|
|
|
const int8x16_t min_vec = vdupq_n_s8(params->min_); |
|
|
|
const int8x16_t max_vac = vdupq_n_s8(params->max_); |
|
|
|
const int8x16_t max_vec = vdupq_n_s8(params->max_); |
|
|
|
|
|
|
|
const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_); |
|
|
|
const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_); |
|
|
|
@@ -293,7 +295,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i |
|
|
|
const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); |
|
|
|
|
|
|
|
const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); |
|
|
|
const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out)); |
|
|
|
const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); |
|
|
|
|
|
|
|
vst1q_s8(output + index, int8_out); |
|
|
|
} |
|
|
|
@@ -325,3 +327,357 @@ int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int |
|
|
|
TileDimensionsInt8(in0, in1, tile_in0, tile_in1, param); |
|
|
|
return ElementAddInt8(tile_in0, tile_in1, out, size); |
|
|
|
} |
|
|
|
|
|
|
|
#ifdef ENABLE_AVX |
|
|
|
void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { |
|
|
|
const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); |
|
|
|
const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); |
|
|
|
const __m128i min_vec = _mm_set1_epi8(params->min_); |
|
|
|
const __m128i max_vec = _mm_set1_epi8(params->max_); |
|
|
|
const __m128i in0_zp_vec = _mm_set1_epi16(params->in0_args_.zp_); |
|
|
|
const __m128i in1_zp_vec = _mm_set1_epi16(params->in1_args_.zp_); |
|
|
|
const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); |
|
|
|
const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); |
|
|
|
const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); |
|
|
|
const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); |
|
|
|
const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); |
|
|
|
const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); |
|
|
|
int index = 0; |
|
|
|
for (; index <= size - 16; index += 16) { |
|
|
|
const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(input0 + index)); |
|
|
|
const __m128i in1_src = _mm_loadu_si128((__m128i_u *)(input1 + index)); |
|
|
|
|
|
|
|
const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); |
|
|
|
const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); |
|
|
|
const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); |
|
|
|
const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); |
|
|
|
const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); |
|
|
|
const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); |
|
|
|
|
|
|
|
const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); |
|
|
|
const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); |
|
|
|
const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); |
|
|
|
const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); |
|
|
|
|
|
|
|
__m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); |
|
|
|
__m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); |
|
|
|
__m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); |
|
|
|
tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); |
|
|
|
__m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); |
|
|
|
__m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); |
|
|
|
__m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); |
|
|
|
__m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); |
|
|
|
__m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); |
|
|
|
tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); |
|
|
|
__m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); |
|
|
|
__m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); |
|
|
|
in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); |
|
|
|
in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); |
|
|
|
in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); |
|
|
|
in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); |
|
|
|
in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); |
|
|
|
in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); |
|
|
|
in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier. |
|
|
|
in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); |
|
|
|
in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); |
|
|
|
in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); |
|
|
|
in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); |
|
|
|
in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); |
|
|
|
in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); |
|
|
|
in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); |
|
|
|
in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; |
|
|
|
int32_t in0_remainder_threshold = in0_remainder_mask >> 1; |
|
|
|
const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); |
|
|
|
const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); |
|
|
|
const __m128i in0_1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); |
|
|
|
in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); |
|
|
|
in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); |
|
|
|
in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); |
|
|
|
in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); |
|
|
|
|
|
|
|
int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; |
|
|
|
int32_t in1_remainder_threshold = in1_remainder_mask >> 1; |
|
|
|
const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); |
|
|
|
const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); |
|
|
|
const __m128i in1_1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); |
|
|
|
in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); |
|
|
|
in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); |
|
|
|
in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); |
|
|
|
in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); |
|
|
|
|
|
|
|
/* calculate output */ |
|
|
|
__m128i out1 = _mm_add_epi32(in0_1, in1_1); |
|
|
|
__m128i out2 = _mm_add_epi32(in0_2, in1_2); |
|
|
|
__m128i out3 = _mm_add_epi32(in0_3, in1_3); |
|
|
|
__m128i out4 = _mm_add_epi32(in0_4, in1_4); |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
out1 = _mm_slli_epi32(out1, params->out_left_shift_); |
|
|
|
out2 = _mm_slli_epi32(out2, params->out_left_shift_); |
|
|
|
out3 = _mm_slli_epi32(out3, params->out_left_shift_); |
|
|
|
out4 = _mm_slli_epi32(out4, params->out_left_shift_); |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier. |
|
|
|
out1 = _mm_qrdmulh_epi32(out1, out_multiplier); |
|
|
|
out2 = _mm_qrdmulh_epi32(out2, out_multiplier); |
|
|
|
out3 = _mm_qrdmulh_epi32(out3, out_multiplier); |
|
|
|
out4 = _mm_qrdmulh_epi32(out4, out_multiplier); |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; |
|
|
|
int32_t out_remainder_threshold = out_remainder_mask >> 1; |
|
|
|
const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); |
|
|
|
const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); |
|
|
|
const __m128i out1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); |
|
|
|
out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); |
|
|
|
out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); |
|
|
|
out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); |
|
|
|
out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); |
|
|
|
|
|
|
|
__m128i out1_s16 = _mm_packs_epi32(out1, out2); |
|
|
|
__m128i out2_s16 = _mm_packs_epi32(out3, out4); |
|
|
|
|
|
|
|
__m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); |
|
|
|
__m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); |
|
|
|
__m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); |
|
|
|
__m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); |
|
|
|
|
|
|
|
_mm_storeu_si128((__m128i_u *)(output + index), int8_out); |
|
|
|
} |
|
|
|
for (; index < size; index++) { |
|
|
|
const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; |
|
|
|
const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; |
|
|
|
const int32_t in0 = |
|
|
|
MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); |
|
|
|
const int32_t in1 = |
|
|
|
MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); |
|
|
|
|
|
|
|
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, |
|
|
|
-params->out_right_shift_); |
|
|
|
out += params->out_zp_; |
|
|
|
output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params, |
|
|
|
AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args) { |
|
|
|
// input0: ptr_in |
|
|
|
// input1: element_in |
|
|
|
// load quant parameters of input0 and input1 |
|
|
|
const int in0_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); |
|
|
|
const int in1_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); |
|
|
|
const __m128i min_vec = _mm_set1_epi8(params->min_); |
|
|
|
const __m128i max_vec = _mm_set1_epi8(params->max_); |
|
|
|
const __m128i in0_zp_vec = _mm_set1_epi16(ptr_args->zp_); |
|
|
|
const __m128i in1_zp_vec = _mm_set1_epi16(ele_args->zp_); |
|
|
|
const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); |
|
|
|
const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); |
|
|
|
const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); |
|
|
|
const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); |
|
|
|
const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); |
|
|
|
const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); |
|
|
|
|
|
|
|
// input1 can be processed once because it is const |
|
|
|
const __m128i in1_src = _mm_set1_epi8(element_in); |
|
|
|
const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); |
|
|
|
const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); |
|
|
|
const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); |
|
|
|
const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); |
|
|
|
const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); |
|
|
|
__m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); |
|
|
|
__m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); |
|
|
|
__m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); |
|
|
|
tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); |
|
|
|
__m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); |
|
|
|
__m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec); |
|
|
|
in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec); |
|
|
|
in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec); |
|
|
|
in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec); |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier. |
|
|
|
in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier); |
|
|
|
in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier); |
|
|
|
in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier); |
|
|
|
in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier); |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1; |
|
|
|
int32_t in1_remainder_threshold = in1_remainder_mask >> 1; |
|
|
|
const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); |
|
|
|
const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); |
|
|
|
const __m128i in1_1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1)); |
|
|
|
in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2)); |
|
|
|
in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3)); |
|
|
|
in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold)); |
|
|
|
const __m128i in1_4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4)); |
|
|
|
in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold)); |
|
|
|
|
|
|
|
int index = 0; |
|
|
|
for (; index <= size - 16; index += 16) { |
|
|
|
const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(ptr_in + index)); |
|
|
|
const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); |
|
|
|
const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); |
|
|
|
const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); |
|
|
|
const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); |
|
|
|
const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); |
|
|
|
|
|
|
|
__m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); |
|
|
|
__m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); |
|
|
|
__m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); |
|
|
|
tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); |
|
|
|
__m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); |
|
|
|
__m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec); |
|
|
|
in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec); |
|
|
|
in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec); |
|
|
|
in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec); |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier. |
|
|
|
in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier); |
|
|
|
in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier); |
|
|
|
in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier); |
|
|
|
in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier); |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1; |
|
|
|
int32_t in0_remainder_threshold = in0_remainder_mask >> 1; |
|
|
|
const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask); |
|
|
|
const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold); |
|
|
|
const __m128i in0_1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1)); |
|
|
|
in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2)); |
|
|
|
in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3)); |
|
|
|
in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold)); |
|
|
|
const __m128i in0_4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4)); |
|
|
|
in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_), |
|
|
|
_mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold)); |
|
|
|
|
|
|
|
/* calculate output */ |
|
|
|
__m128i out1 = _mm_add_epi32(in0_1, in1_1); |
|
|
|
__m128i out2 = _mm_add_epi32(in0_2, in1_2); |
|
|
|
__m128i out3 = _mm_add_epi32(in0_3, in1_3); |
|
|
|
__m128i out4 = _mm_add_epi32(in0_4, in1_4); |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
out1 = _mm_slli_epi32(out1, params->out_left_shift_); |
|
|
|
out2 = _mm_slli_epi32(out2, params->out_left_shift_); |
|
|
|
out3 = _mm_slli_epi32(out3, params->out_left_shift_); |
|
|
|
out4 = _mm_slli_epi32(out4, params->out_left_shift_); |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier. |
|
|
|
out1 = _mm_qrdmulh_epi32(out1, out_multiplier); |
|
|
|
out2 = _mm_qrdmulh_epi32(out2, out_multiplier); |
|
|
|
out3 = _mm_qrdmulh_epi32(out3, out_multiplier); |
|
|
|
out4 = _mm_qrdmulh_epi32(out4, out_multiplier); |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; |
|
|
|
int32_t out_remainder_threshold = out_remainder_mask >> 1; |
|
|
|
const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); |
|
|
|
const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); |
|
|
|
const __m128i out1_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); |
|
|
|
out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out2_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); |
|
|
|
out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out3_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); |
|
|
|
out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); |
|
|
|
const __m128i out4_remainder = |
|
|
|
_mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); |
|
|
|
out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), |
|
|
|
_mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); |
|
|
|
|
|
|
|
__m128i out1_s16 = _mm_packs_epi32(out1, out2); |
|
|
|
__m128i out2_s16 = _mm_packs_epi32(out3, out4); |
|
|
|
|
|
|
|
__m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); |
|
|
|
__m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); |
|
|
|
__m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); |
|
|
|
__m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); |
|
|
|
|
|
|
|
_mm_storeu_si128((__m128i_u *)(output + index), int8_out); |
|
|
|
} |
|
|
|
for (; index < size; index++) { |
|
|
|
const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift; |
|
|
|
const int32_t in1_left = (element_in + ele_args->zp_) * in1_left_shift; |
|
|
|
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, ptr_args->multiplier_, ptr_args->right_shift_); |
|
|
|
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, ele_args->multiplier_, ele_args->right_shift_); |
|
|
|
|
|
|
|
int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, |
|
|
|
-params->out_right_shift_); |
|
|
|
out += params->out_zp_; |
|
|
|
output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
#endif |