You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fixed_point.c 8.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/quantization/fixed_point.h"
  // Saturating, rounding, doubling high multiply of two Q0.31 values.
  //
  // Treating a and b as the fractions a / 2^31 and b / 2^31 in [-1, 1), the
  // result is (a * b) / 2^31 rounded to an integer, i.e. the high 32 bits of
  // 2 * a * b. Positive ties round away from zero, negative ties toward zero.
  // The only unrepresentable case, INT_MIN * INT_MIN (-1 * -1), saturates to
  // INT_MAX.
  int SaturatingRoundingDoublingHighMul(int a, int b) {
    if (a == INT_MIN && b == INT_MIN) {
      return INT_MAX;
    }
    const int64_t product = (int64_t)a * (int64_t)b;
    // Half an output ulp (2^30); one less in magnitude for negative products.
    const int64_t nudge = (product < 0) ? (1ll - (1ll << 30)) : (1ll << 30);
    // Divide (truncates toward zero) instead of right-shifting, which would
    // round negative values toward -infinity.
    const int64_t high = (product + nudge) / (1ll << 31);
    return (int)high;
  }
  // Q0.15 counterpart of SaturatingRoundingDoublingHighMul: returns
  // (a * b) / 2^15 rounded with the same tie rule (positive ties away from
  // zero, negative ties toward zero). The single overflowing input pair,
  // SHRT_MIN * SHRT_MIN, saturates to SHRT_MAX.
  int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
    if (a == SHRT_MIN && b == SHRT_MIN) {
      return SHRT_MAX;
    }
    const int32_t product = (int32_t)a * (int32_t)b;
    const int32_t nudge = (product < 0) ? (1 - (1 << 14)) : (1 << 14);
    // Truncating division, not a shift, so negative values round correctly.
    const int32_t high = (product + nudge) / (1 << 15);
    return (int16_t)high;
  }
  // Divides x by 2^exponent, rounding to nearest with ties away from zero
  // (an arithmetic right shift with rounding).
  // exponent must be in [0, 31]; other values make the shifts below undefined
  // or implementation-defined.
  int RoundingDivideByPOT(int x, int exponent) {
    const int low_bits_mask = (int)((1ll << exponent) - 1);
    const int dropped_bits = x & low_bits_mask;
    // Half of the divisor; one more for negatives so their ties round down.
    int round_up_threshold = low_bits_mask >> 1;
    if (x < 0) {
      round_up_threshold += 1;
    }
    int quotient = x >> exponent;
    if (dropped_bits > round_up_threshold) {
      quotient += 1;
    }
    return quotient;
  }
  // Rescales `value` by a quantized multiplier:
  //   RoundingDivideByPOT(SRDHM(value * 2^left_shift, multiplier), -right_shift)
  // where `multiplier` is consumed as a Q0.31 fraction by
  // SaturatingRoundingDoublingHighMul.
  // right_shift is negated before being used as RoundingDivideByPOT's exponent,
  // so it is expected to be <= 0. left_shift is expected to be in [0, 31].
  int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
    // Shift in unsigned arithmetic. The original `value * (1 << left_shift)` is
    // undefined on signed overflow (and `1 << 31` is undefined by itself). This
    // form yields the same result whenever the product fits in int32.
    const int32_t shifted_value = (int32_t)((uint32_t)value << (uint32_t)left_shift);
    const int32_t high_product = SaturatingRoundingDoublingHighMul(shifted_value, multiplier);
    return RoundingDivideByPOT(high_product, -right_shift);
  }
  // Number of fractional bits in a signed 32-bit fixed-point value that has
  // `integer_bits` integer bits (one bit is taken by the sign).
  int FractionsBits(int integer_bits) {
    const int value_bits = (int)(8 * sizeof(int32_t)) - 1;
    return value_bits - integer_bits;
  }
  // Raw representation of 1.0 in a fixed-point format.
  // With zero integer bits 1.0 is not representable, so the largest value
  // (INT32_MAX) is returned; otherwise 1.0 is 2^fractions_bits.
  int FixedPoint_One(int integer_bits, int fractions_bits) {
    if (integer_bits == 0) {
      return INT32_MAX;
    }
    return 1 << (uint32_t)fractions_bits;
  }
  // (a + b) / 2 computed in 64 bits (no overflow), rounded half away from zero.
  int RoundingHalfSum(int32_t a, int32_t b) {
    const int64_t total = (int64_t)a + (int64_t)b;
    int64_t nudge = -1;
    if (total > 0) {
      nudge = 1;
    }
    return (int32_t)((total + nudge) / 2);
  }
  // Bitwise helpers on int32 values; the operations are done on the unsigned
  // representation so no signed bit manipulation is involved.
  int32_t BitAnd(int32_t a, int32_t b) {
    const uint32_t lhs = (uint32_t)a;
    const uint32_t rhs = (uint32_t)b;
    return (int32_t)(lhs & rhs);
  }
  int32_t BitOr(int32_t a, int32_t b) {
    const uint32_t lhs = (uint32_t)a;
    const uint32_t rhs = (uint32_t)b;
    return (int32_t)(lhs | rhs);
  }
  int32_t BitXor(int32_t a, int32_t b) {
    const uint32_t lhs = (uint32_t)a;
    const uint32_t rhs = (uint32_t)b;
    return (int32_t)(lhs ^ rhs);
  }
  int32_t BitNot(int32_t a) {
    const uint32_t bits = (uint32_t)a;
    return (int32_t)(~bits);
  }
  // Per-bit select: bits set in `mask` come from `bound`, clear bits from `val`.
  // The two parts are disjoint, so XOR combines them like OR.
  int BitsSelect(int mask, int bound, int val) {
    const int32_t from_bound = BitAnd(mask, bound);
    const int32_t from_val = BitAnd(BitNot(mask), val);
    return BitXor(from_bound, from_val);
  }
  // Raw value of 2^exponent in a format with `fractional_bits` fractional bits.
  // fractional_bits + exponent is expected to be in [0, 30].
  int ConstantPOT(int fractional_bits, int exponent) {
    const uint32_t bit_index = (uint32_t)(fractional_bits + exponent);
    return 1 << bit_index;
  }
  // All-ones mask (BitNot(0)) when a is non-zero, otherwise 0.
  int32_t MaskIfNonZero(int32_t a) {
    if (a == 0) {
      return 0;
    }
    return BitNot(0);
  }
  // All-ones mask when a is zero, otherwise 0.
  int32_t MaskIfZero(int32_t a) {
    const int32_t is_zero = (a == 0);
    return MaskIfNonZero(is_zero);
  }
  // All-ones mask when a < b, otherwise 0.
  int32_t MaskIfLessThan(int32_t a, int32_t b) {
    const int32_t is_less = (a < b);
    return MaskIfNonZero(is_less);
  }
  // Number of leading zero bits in x; 32 when x == 0.
  int CountLeadingZeroBits(uint32_t x) {
    // Handled up front: __builtin_clz(0) is undefined.
    if (x == 0) {
      return (int)(8 * sizeof(uint32_t));
    }
  #if defined(__GNUC__)
    // Fixed macro name (was misspelled __GUNC__, so this path never compiled in).
    return __builtin_clz(x);
  #else
    // Unsigned top-bit constant: `(int32_t)1 << 31` is undefined behavior.
    const uint32_t top_bit = (uint32_t)1 << (8 * sizeof(uint32_t) - 1);
    int leading_zeros = 0;
    while ((x & top_bit) == 0) {
      x <<= 1;
      leading_zeros++;
    }
    return leading_zeros;
  #endif
  }
  // Number of leading redundant sign bits of x: how many bits after the most
  // significant bit are equal to it (same definition as GCC __builtin_clrsb).
  // Examples: 0 -> 31, -1 -> 31, 1 -> 30, INT32_MIN -> 0, INT32_MAX -> 0.
  int CountLeadingSignBits(int32_t x) {
  #if defined(__GNUC__) && !defined(__clang__)
    // Fixed macro name (was __GUNC__). __builtin_clrsb is defined for every
    // input, including 0 (31); the old guard returned 32 there, which disagreed
    // with the fallback below.
    return __builtin_clrsb(x);
  #else
    // For negative x, ~x has exactly as many leading zeros as x has leading
    // ones. The previous `2 * (uint32_t)(-x)` form was off by one for negative
    // powers of two (e.g. returned 30 for -1).
    const uint32_t magnitude_bits = x < 0 ? ~(uint32_t)x : (uint32_t)x;
    return CountLeadingZeroBits(magnitude_bits) - 1;
  #endif
  }
  // Multiplies the fixed-point raw value x by 2^exponent.
  // exponent > 0: left shift, clamping results outside int32 to INT32_MIN/MAX.
  // exponent < 0: rounding right shift via RoundingDivideByPOT.
  // exponent == 0: x unchanged.
  // exponent is expected to be <= 31; larger values make the threshold shift
  // below undefined.
  int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) {
    if (exponent > 0) {
      const int min = INT32_MIN;
      const int max = INT32_MAX;
      const int scalar_int_bits = (int)(8 * sizeof(int32_t));
      // Largest x with x * 2^exponent <= INT32_MAX.
      const int threshold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1);
      const int positive_mask = x > threshold ? BitNot(0) : 0;
      const int negative_mask = x < -threshold ? BitNot(0) : 0;
      // Shift in unsigned arithmetic. Out-of-range x is replaced by the masks
      // below, but the previous signed multiply was already undefined behavior
      // for exactly those inputs. In-range results are identical.
      int result = (int32_t)((uint32_t)x << (uint32_t)exponent);
      result = BitsSelect(positive_mask, max, result);
      result = BitsSelect(negative_mask, min, result);
      return result;
    }
    if (exponent < 0) {
      return RoundingDivideByPOT(x, -exponent);
    }
    return x;
  }
  // Converts raw value x from a format with integer_bits_src integer bits to
  // one with integer_bits_dst integer bits (saturating / rounding as needed).
  int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) {
    return SaturatingRoundingMultiplyByPOT(x, integer_bits_src - integer_bits_dst);
  }
  // Approximates 1 / (1 + a) for a Q0.31 value a (presumably a in [0, 1)),
  // returning a Q0.31 result.
  // Let d = (1 + a) / 2. Three Newton-Raphson steps refine an estimate of 1 / d
  // held in Q2.29; halving it gives 1 / (1 + a).
  int32_t reciprocal_on_interval_between_0_1(int32_t a) {
    // 48/17 and -32/17 in Q2.29 (raw / 2^29): linear initial guess for 1 / d.
    const int constant_48_over_17 = 1515870810;
    const int constant_neg_32_over_17 = -1010580540;
    const int one_q0 = FixedPoint_One(0, FractionsBits(0));
    const int one_q2 = FixedPoint_One(2, FractionsBits(2));
    const int half_denominator = RoundingHalfSum(a, one_q0);
    int estimate = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17);
    for (int iter = 0; iter < 3; ++iter) {
      // error = 1 - d * estimate (Q2.29)
      const int error = one_q2 - SaturatingRoundingDoublingHighMul(half_denominator, estimate);
      // estimate * error carries 2 + 2 integer bits; bring it back to Q2.
      estimate += Rescale(SaturatingRoundingDoublingHighMul(estimate, error), 4, 2);
    }
    // Reading the Q2 raw as Q1 halves it; then convert Q1 -> Q0.
    return Rescale(estimate, 1, 0);
  }
  // Reciprocal of a positive fixed-point value x that has x_digits integer bits.
  // x is normalized to 1 + f with f in [0, 1), and the Q0.31 value
  // 1 / (1 + f) is returned. *recip_shift receives x_digits - CLZ(x), so that
  // 1 / x_real ~= result * 2^(-*recip_shift), where x_real = x / 2^(31 - x_digits).
  // NOTE(review): x is assumed > 0; x == 0 makes the shift below 32 bits (UB).
  int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift) {
    const int headroom_plus_one = CountLeadingZeroBits((uint32_t)x);
    *recip_shift = x_digits - headroom_plus_one;
    // Move the leading 1 to bit 31, then drop it to keep only the fraction f.
    const uint32_t normalized = (uint32_t)x << headroom_plus_one;
    const uint32_t implicit_one = (uint32_t)(1) << 31;
    const int32_t fraction = (int32_t)(normalized - implicit_one);
    return reciprocal_on_interval_between_0_1(fraction);
  }
  // exp(a) for a Q0.31 input, evaluated as a 4th-order Taylor expansion around
  // -1/8. exp_on_negative_values calls it with inputs in [-1/4, 0).
  // Returns a Q0.31 value.
  int exp_on_interval_values(int a) {
    // exp(-1/8) in Q0.31 (1895147668 / 2^31 ~= 0.8825).
    const int exp_neg_1_over_8 = 1895147668;
    // 1/3 in Q0.31.
    const int one_third = 715827883;
    // Change of variable x = a + 1/8, with 1/8 == 2^-3.
    const int x = a + ConstantPOT(FractionsBits(0), -3);
    const int x_sq = SaturatingRoundingDoublingHighMul(x, x);
    const int x_cube = SaturatingRoundingDoublingHighMul(x_sq, x);
    const int x_4th = SaturatingRoundingDoublingHighMul(x_sq, x_sq);
    const int x_4th_over_4 = SaturatingRoundingMultiplyByPOT(x_4th, -2);
    // ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2
    const int inner = SaturatingRoundingDoublingHighMul((x_4th_over_4 + x_cube), one_third) + x_sq;
    const int higher_terms = SaturatingRoundingMultiplyByPOT(inner, -1);
    // exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24)
    return exp_neg_1_over_8 + SaturatingRoundingDoublingHighMul(exp_neg_1_over_8, (x + higher_terms));
  }
  // One stage of the exp barrel shifter.
  // If the remainder bit worth 2^exponent (raw bit fractional_bits + exponent) is
  // set, *result is multiplied by `muliplier` (a Q0.31 constant; callers pass
  // exp(-2^exponent)). Stages with exponent >= integer_bits are skipped.
  void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder,
                          int *result) {
    if (integer_bits <= exponent) {
      return;
    }
    const int total_shift = fractional_bits + exponent;
    const int bit_mask = MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift)));
    const int scaled = SaturatingRoundingDoublingHighMul(*result, muliplier);
    *result = BitsSelect(bit_mask, scaled, *result);
  }
  // exp(a) for a non-positive fixed-point value a with `integer_bits` integer bits.
  // Returns a Q0.31 value; a == 0 yields INT32_MAX (~1.0).
  // a is split as (a mod 1/4 - 1/4) - remainder. The first part is handled by
  // exp_on_interval_values. Each set bit 2^k of remainder (k = -2..4) then
  // multiplies the result by exp(-2^k).
  int exp_on_negative_values(int a, const int integer_bits) {
    // exp(-2^k) in Q0.31 for k = -2, -1, 0, 1, 2, 3, 4 (applied in this order).
    static const int kStageExponents[] = {-2, -1, +0, +1, +2, +3, +4};
    static const int kStageMultipliers[] = {1672461947, 1302514674, 790015084, 290630308, 39332535, 720401, 242};
    const int num_stages = (int)(sizeof(kStageExponents) / sizeof(kStageExponents[0]));

    const int fractional_bits = FractionsBits(integer_bits);
    const int one_quarter = ConstantPOT(fractional_bits, -2);
    // Lies in [-1/4, 0).
    const int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter;
    int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0));
    const int remainder = a_mod_quarter_minus_one_quarter - a;
    for (int stage = 0; stage < num_stages; ++stage) {
      exp_barrel_shifter(kStageExponents[stage], kStageMultipliers[stage], integer_bits, fractional_bits, remainder,
                         &result);
    }
    // With more than 5 integer bits, inputs below -32 (raw -(1 << (36 - integer_bits)))
    // flush to 0.
    if (integer_bits > 5) {
      const int clamp = -(1 << (uint32_t)(36 - integer_bits));
      result = BitsSelect(MaskIfLessThan(a, clamp), 0, result);
    }
    return BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result);
  }
  #ifdef ENABLE_NEON
  // Lane-wise RoundingDivideByPOT for int32x4_t (rounding right shift by exponent).
  int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
    // vrshlq_s32 with a negative per-lane count performs a rounding right shift.
    const int32x4_t neg_shift = vdupq_n_s32(-exponent);
    // -1 in lanes where both x and neg_shift are negative (x < 0, exponent > 0),
    // else 0. Subtracting it makes negative ties round away from zero, like the
    // scalar version (e.g. -5 / 2 -> -3).
    const int32x4_t sign_fixup = vshrq_n_s32(vandq_s32(x, neg_shift), 31);
    const int32x4_t adjusted = vqaddq_s32(x, sign_fixup);
    return vrshlq_s32(adjusted, neg_shift);
  }
  // Lane-wise saturating rounding doubling high multiply (NEON vqrdmulhq_s32).
  int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) {
    const int32x4_t high = vqrdmulhq_s32(a, b);
    return high;
  }
  #endif