You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

float16.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
  17. #define MINDSPORE_CORE_BASE_FLOAT16_H_
  18. #if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
  19. // Built for lite and ARM
  20. #include <arm_neon.h>
  21. using float16 = float16_t;
  22. #else
  #include <climits>
  #include <cmath>
  #include <cstdint>
  #include <cstring>
  #include <functional>
  #include <limits>
  #include <ostream>
  29. // Implement Float16 for mindspore, inspired by Eigen::half.
  30. namespace mindspore {
  // Software half-precision (IEEE 754 binary16) value.
  //
  // The value is stored as its raw 16-bit pattern. All arithmetic is carried
  // out by widening to float, operating there, and rounding the result back to
  // half precision. Conversions reinterpret bits with std::memcpy rather than
  // by reading an inactive union member, so they stay well-defined in ISO C++.
  class Float16 {
   public:
    // Mask selecting exponent and mantissa (everything except the sign bit).
    static constexpr uint16_t value_mask = 0x7fff;
    // Bit pattern produced for NaN results of float -> half conversion.
    static constexpr uint16_t nan_value = 0x7e00;
    // Bit pattern of +infinity (all exponent bits set, zero mantissa).
    static constexpr uint16_t inf_value = 0x7c00;
    // Bit pattern of 1.0, used when constructing from `true`.
    static constexpr uint16_t true_value = 0x3c00;

    // 32-bit float/integer overlay. Kept as a public type for backward
    // compatibility; the conversions below no longer rely on it.
    union Union32 {
      uint32_t u;
      float f;
    };

    // The default constructor leaves the stored bits uninitialized.
    Float16() = default;
    ~Float16() = default;
    Float16(const Float16 &other) noexcept = default;
    Float16(Float16 &&other) noexcept = default;
    Float16 &operator=(const Float16 &other) noexcept = default;
    Float16 &operator=(Float16 &&other) noexcept = default;

    // Builds a Float16 directly from a raw 16-bit pattern (no conversion).
    static Float16 FromRaw(uint16_t v) {
      Float16 f;
      f.value_ = v;
      return f;
    }

    // Rounds a float to half precision.
    explicit Float16(float f) : value_(FromFloat32(f)) {}
    // true -> 1.0, false -> +0.0.
    explicit Float16(bool b) : value_(b ? true_value : 0) {}
    // Any other type is first converted to float, then rounded to half.
    template <typename T>
    explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

    // Returns the raw 16-bit pattern.
    uint16_t int_value() const { return value_; }

    // False only for +0 and -0; every other pattern (including NaN) is true.
    explicit operator bool() const { return (value_ & value_mask) != 0; }
    explicit operator float() const { return ToFloat32(*this); }
    explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
    explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
    explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
    explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
    explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
    explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
    explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
    explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
    explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

    // Compound assignments: compute in float, round the result back to half.
    Float16 &operator+=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
      return *this;
    }
    Float16 &operator-=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
      return *this;
    }
    Float16 &operator*=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
      return *this;
    }
    Float16 &operator/=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
      return *this;
    }

    // Widens a half-precision value to float.
    static float ToFloat32(Float16 f16) {
      constexpr uint32_t magic_bits = 113u << 23;
      constexpr uint32_t exponent_adjust = (127u - 15u) << 23;
      constexpr uint32_t inf_extra_exp_adjust = (128u - 16u) << 23;
      constexpr uint32_t zero_extra_exp_adjust = 1u << 23;
      constexpr uint32_t sign_mask = 0x8000;
      constexpr uint32_t shifted_exp = 0x7c00u << 13;  // Exponent mask after shift.
      constexpr unsigned int exponent_bits = 13;
      constexpr unsigned int sign_bit_shift = 16;

      // Exponent/mantissa bits moved to their float positions.
      uint32_t bits = static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits;
      // Just the exponent.
      const uint32_t exp = shifted_exp & bits;
      bits += exponent_adjust;
      // Handle exponent special cases.
      if (exp == shifted_exp) {
        // Inf/NaN, extra exp adjust.
        bits += inf_extra_exp_adjust;
      } else if (exp == 0) {
        // Zero/Denormal, extra exp adjust and renormalize via a float subtract.
        bits += zero_extra_exp_adjust;
        bits = FloatToBits(BitsToFloat(bits) - BitsToFloat(magic_bits));
      }
      // Set sign bit.
      bits |= (static_cast<uint32_t>(f16.value_ & sign_mask) << sign_bit_shift);
      return BitsToFloat(bits);
    }

   private:
    // Reinterprets a 32-bit pattern as a float.
    static float BitsToFloat(uint32_t bits) {
      static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // Reinterprets a float as its 32-bit pattern.
    static uint32_t FloatToBits(float f) {
      static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      return bits;
    }

    // Rounds a float to the raw bits of the nearest half-precision value.
    static uint16_t FromFloat32(float f32) {
      constexpr uint32_t magic = 113u << 23;
      constexpr uint32_t f32infty = 255u << 23;
      constexpr uint32_t f16max = (127u + 16u) << 23;
      constexpr uint32_t denorm_magic = ((127u - 15u) + (23u - 10u) + 1u) << 23;
      constexpr unsigned int exponent_bits = 13;
      constexpr unsigned int sign_bit_shift = 16;
      constexpr uint32_t sign_mask = 0x80000000u;
      constexpr uint32_t rounding_bias_part1 = (static_cast<uint32_t>(15 - 127) << 23) + 0xfff;

      uint32_t bits = FloatToBits(f32);
      const uint32_t sign = bits & sign_mask;
      bits ^= sign;  // Work on the magnitude; the sign is re-applied at the end.
      uint16_t result = 0;
      // NOTE all the integer compares in this function can be safely
      // compiled into signed compares since all operands are below
      // 0x80000000. Important if you want fast straight SSE2 code
      // (since there's no unsigned PCMPGTD).
      if (bits >= f16max) {
        // Result is Inf or NaN (all exponent bits set).
        result = (bits > f32infty) ? nan_value : inf_value;
      } else if (bits < magic) {
        // (De)normalized number or zero; resulting FP16 is subnormal or zero.
        // Use a magic value to align our 10 mantissa bits at the bottom of
        // the float. As long as FP addition is round-to-nearest-even this
        // just works.
        const float aligned = BitsToFloat(bits) + BitsToFloat(denorm_magic);
        // One integer subtract of the bias later, we have the final bits.
        result = static_cast<uint16_t>(FloatToBits(aligned) - denorm_magic);
      } else {
        // Resulting mantissa is odd.
        const uint32_t mant_odd = (bits >> exponent_bits) & 1u;
        // Update exponent, rounding bias part 1.
        bits += rounding_bias_part1;
        // Rounding bias part 2.
        bits += mant_odd;
        // Take the bits!
        result = static_cast<uint16_t>(bits >> exponent_bits);
      }
      // Set sign bit.
      result |= static_cast<uint16_t>(sign >> sign_bit_shift);
      return result;
    }

    uint16_t value_;
  };
  156. inline Float16 operator+(const Float16 &a, const Float16 &b) {
  157. return Float16(static_cast<float>(a) + static_cast<float>(b));
  158. }
  159. inline Float16 operator*(const Float16 &a, const Float16 &b) {
  160. return Float16(static_cast<float>(a) * static_cast<float>(b));
  161. }
  162. inline Float16 operator-(const Float16 &a, const Float16 &b) {
  163. return Float16(static_cast<float>(a) - static_cast<float>(b));
  164. }
  165. inline Float16 operator/(const Float16 &a, const Float16 &b) {
  166. return Float16(static_cast<float>(a) / static_cast<float>(b));
  167. }
  168. // Division by an size_t. Do it in full float precision to avoid
  169. // accuracy issues in converting the denominator to float16.
  170. inline Float16 operator/(const Float16 &a, size_t b) { return Float16(static_cast<float>(a) / static_cast<float>(b)); }
  171. inline Float16 operator-(const Float16 &a) {
  172. constexpr uint16_t sign_mask = 0x8000;
  173. return Float16::FromRaw(a.int_value() ^ sign_mask);
  174. }
  175. inline bool operator==(const Float16 &a, const Float16 &b) {
  176. return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
  177. }
  178. inline bool operator!=(const Float16 &a, const Float16 &b) {
  179. return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
  180. }
  181. inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
  182. inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
  183. inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
  184. inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }
  185. inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }
  186. } // namespace mindspore
  187. using float16 = mindspore::Float16;
  188. namespace std {
  189. template <>
  190. struct hash<float16> {
  191. std::size_t operator()(const float16 &f16) const noexcept { return static_cast<std::size_t>(f16.int_value()); }
  192. };
  193. template <>
  194. struct numeric_limits<float16> {
  195. static constexpr bool is_specialized = true;
  196. static constexpr bool is_signed = true;
  197. static constexpr bool is_integer = false;
  198. static constexpr bool is_exact = false;
  199. static constexpr bool has_infinity = true;
  200. static constexpr bool has_quiet_NaN = true;
  201. static constexpr bool has_signaling_NaN = true;
  202. static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  203. static constexpr bool has_denorm_loss = false;
  204. static constexpr std::float_round_style round_style = std::round_to_nearest;
  205. static constexpr bool is_iec559 = false;
  206. static constexpr bool is_bounded = false;
  207. static constexpr bool is_modulo = false;
  208. static constexpr int digits = 11;
  209. static constexpr int digits10 = 3;
  210. static constexpr int max_digits10 = 5;
  211. static constexpr int radix = 2;
  212. static constexpr int min_exponent = -13;
  213. static constexpr int min_exponent10 = -4;
  214. static constexpr int max_exponent = 16;
  215. static constexpr int max_exponent10 = 4;
  216. static constexpr bool traps = true;
  217. static constexpr bool tinyness_before = false;
  218. static constexpr uint16_t raw_min = 0x400;
  219. static constexpr uint16_t raw_max = 0x7bff;
  220. static constexpr uint16_t raw_lowest = 0xfbff;
  221. static constexpr uint16_t raw_epsilon = 0x0800;
  222. static constexpr float round_error_value = 0.5;
  223. static float16(min)() noexcept { return float16::FromRaw(raw_min); }
  224. static float16(max)() noexcept { return float16::FromRaw(raw_max); }
  225. static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
  226. static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
  227. static float16 round_error() noexcept { return float16(round_error_value); }
  228. static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
  229. static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  230. static float16 signaling_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  231. static float16 denorm_min() noexcept { return float16::FromRaw(1); }
  232. };
  233. // If std::numeric_limits<T> is specialized, should also specialize
  234. // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
  235. // std::numeric_limits<const volatile T>
  236. // https://stackoverflow.com/a/16519653/
  237. template <>
  238. struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  239. template <>
  240. struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  241. template <>
  242. struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  243. } // namespace std
  244. // Implements standard math functions for float16.
  245. inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
  246. inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
  247. inline bool(isfinite)(const float16 &a) { return !(isinf(a)) && !(isnan(a)); }
  248. inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
  249. inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
  250. inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
  251. inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
  252. inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
  253. inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
  254. inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
  255. inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
  256. inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
  257. inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
  258. inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
  259. inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
  260. inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
  261. inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
  262. inline float16 pow(const float16 &a, const float16 &b) {
  263. return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
  264. }
  265. #endif // ENABLE_ARM32 || ENABLE_ARM64
  266. inline float half_to_float(float16 h) { return static_cast<float>(h); }
  267. #endif // MINDSPORE_CORE_BASE_FLOAT16_H_