You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

float16.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
  17. #define MINDSPORE_CORE_BASE_FLOAT16_H_
  18. #if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
  19. // Built for lite and ARM
  20. #include <arm_neon.h>
  21. using float16 = float16_t;
  22. #else
  #include <climits>
  #include <cmath>
  #include <cstdint>
  #include <cstring>
  #include <functional>
  #include <limits>
  #include <ostream>
  // Software implementation of a 16-bit floating point type for mindspore,
  // inspired by Eigen::half. Used when the native ARM float16_t is unavailable.
  namespace mindspore {
  // Half-precision value stored as its raw 16-bit pattern (1 sign bit, 5 exponent
  // bits, 10 mantissa bits). All arithmetic is carried out in float and the
  // result is rounded back to 16 bits.
  class Float16 {
   public:
    // Mask selecting every bit except the sign bit.
    static constexpr uint16_t value_mask = 0x7fff;
    // Raw pattern produced for NaN results.
    static constexpr uint16_t nan_value = 0x7e00;
    // Raw pattern of positive infinity (all exponent bits set, zero mantissa).
    static constexpr uint16_t inf_value = 0x7c00;
    // Raw pattern of 1.0; this is what `true` converts to.
    static constexpr uint16_t true_value = 0x3c00;

    // Kept public for source compatibility. The conversions in this class copy
    // bits with std::memcpy instead of reading an inactive union member.
    union Union32 {
      uint32_t u;
      float f;
    };

    // Default-initialization leaves the stored bits indeterminate.
    Float16() = default;
    ~Float16() = default;
    Float16(const Float16 &other) noexcept = default;
    Float16(Float16 &&other) noexcept = default;
    Float16 &operator=(const Float16 &other) noexcept = default;
    Float16 &operator=(Float16 &&other) noexcept = default;

    // Builds a value directly from a raw 16-bit pattern, without any rounding.
    static Float16 FromRaw(uint16_t v) {
      Float16 f;
      f.value_ = v;
      return f;
    }

    // Rounds a float to the nearest representable half value.
    explicit Float16(float f) : value_(FromFloat32(f)) {}
    // true becomes 1.0, false becomes +0.0.
    explicit Float16(bool b) : value_(b ? true_value : 0) {}
    // Any other type is first converted to float, then rounded to half.
    template <typename T>
    explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

    // Returns the raw 16-bit pattern.
    uint16_t int_value() const { return value_; }

    // False only for +0.0 and -0.0 (the sign bit is ignored).
    explicit operator bool() const { return (value_ & value_mask) != 0; }
    explicit operator float() const { return ToFloat32(*this); }
    explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
    explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
    explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
    explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
    explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
    explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
    explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
    explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
    explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

    // Compound assignments: compute in float, then round the result back to half.
    Float16 &operator+=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
      return *this;
    }
    Float16 &operator-=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
      return *this;
    }
    Float16 &operator*=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
      return *this;
    }
    Float16 &operator/=(const Float16 &b) {
      value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
      return *this;
    }

    // Widens a half value to float. Every half value is exactly representable
    // as a float, so this conversion is lossless.
    static float ToFloat32(Float16 f16) {
      constexpr uint32_t magic_bits = (113u << 23);
      constexpr uint32_t exponent_adjust = ((127 - 15) << 23);
      constexpr uint32_t inf_extra_exp_adjust = ((128 - 16) << 23);
      constexpr uint32_t zero_extra_exp_adjust = (1 << 23);
      constexpr uint32_t sign_mask = 0x8000;
      constexpr uint32_t shifted_exp = (0x7c00 << 13);  // Exponent mask after shift.
      constexpr unsigned int exponent_bits = 13;
      constexpr unsigned int sign_bit_shift = 16;
      // Exponent/mantissa bits moved to their float positions.
      uint32_t bits = static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits;
      // Just the exponent.
      const uint32_t exp = shifted_exp & bits;
      bits += exponent_adjust;
      // Handle exponent special cases.
      if (exp == shifted_exp) {
        // Inf/NaN, extra exp adjust.
        bits += inf_extra_exp_adjust;
      } else if (exp == 0) {
        // Zero/Denormal, extra exp adjust and renormalize through a float subtract.
        bits += zero_extra_exp_adjust;
        bits = FloatToBits(BitsToFloat(bits) - BitsToFloat(magic_bits));
      }
      // Set sign bit.
      bits |= static_cast<uint32_t>(f16.value_ & sign_mask) << sign_bit_shift;
      return BitsToFloat(bits);
    }

   private:
    static_assert(sizeof(float) == sizeof(uint32_t), "Float16 conversions require a 32-bit float");

    // Reinterprets a 32-bit pattern as a float (bit copy, no numeric conversion).
    static float BitsToFloat(uint32_t bits) {
      float value;
      std::memcpy(&value, &bits, sizeof(value));
      return value;
    }

    // Returns the 32-bit pattern of a float (bit copy, no numeric conversion).
    static uint32_t FloatToBits(float value) {
      uint32_t bits;
      std::memcpy(&bits, &value, sizeof(bits));
      return bits;
    }

    // Rounds a float to half precision and returns the raw 16-bit pattern.
    // Values too large for half become infinity and NaN inputs become nan_value.
    static uint16_t FromFloat32(float f32) {
      constexpr uint32_t magic = (113u << 23);
      constexpr uint32_t f32infty = (255u << 23);
      constexpr uint32_t f16max = ((127u + 16u) << 23);
      constexpr uint32_t denorm_magic = (((127u - 15u) + (23u - 10u) + 1u) << 23);
      constexpr unsigned int exponent_bits = 13;
      constexpr unsigned int sign_bit_shift = 16;
      constexpr uint32_t sign_mask = 0x80000000u;
      // Unsigned wrap-around is intended: adding this rebiases the exponent.
      constexpr uint32_t rounding_bias_part1 = (static_cast<uint32_t>(15 - 127) << 23) + 0xfff;
      uint32_t bits = FloatToBits(f32);
      const uint32_t sign = bits & sign_mask;
      bits ^= sign;
      uint16_t result = 0;
      // NOTE all the integer compares in this function can be safely
      // compiled into signed compares since all operands are below
      // 0x80000000. Important if you want fast straight SSE2 code
      // (since there's no unsigned PCMPGTD).
      if (bits >= f16max) {
        // Result is Inf or NaN (all exponent bits set).
        if (bits > f32infty) {
          result = nan_value;
        } else {
          result = inf_value;
        }
      } else if (bits < magic) {
        // (De)normalized number or zero; resulting FP16 is subnormal or zero.
        // Use a magic value to align our 10 mantissa bits at the bottom of
        // the float. As long as FP addition is round-to-nearest-even this
        // just works.
        const float aligned = BitsToFloat(bits) + BitsToFloat(denorm_magic);
        // One integer subtract of the bias later, we have the final half bits.
        result = static_cast<uint16_t>(FloatToBits(aligned) - denorm_magic);
      } else {
        // Whether the resulting mantissa is odd (needed for round-to-nearest-even).
        const uint32_t mant_odd = (bits >> exponent_bits) & 1;
        // Update exponent, rounding bias part 1.
        bits += rounding_bias_part1;
        // Rounding bias part 2.
        bits += mant_odd;
        // Take the bits!
        result = static_cast<uint16_t>(bits >> exponent_bits);
      }
      // Set sign bit.
      result |= static_cast<uint16_t>(sign >> sign_bit_shift);
      return result;
    }

    uint16_t value_;
  };

  // Binary arithmetic: operands are widened to float, combined, and the result
  // is rounded back to half.
  inline Float16 operator+(const Float16 &a, const Float16 &b) {
    return Float16(static_cast<float>(a) + static_cast<float>(b));
  }
  inline Float16 operator*(const Float16 &a, const Float16 &b) {
    return Float16(static_cast<float>(a) * static_cast<float>(b));
  }
  inline Float16 operator-(const Float16 &a, const Float16 &b) {
    return Float16(static_cast<float>(a) - static_cast<float>(b));
  }
  inline Float16 operator/(const Float16 &a, const Float16 &b) {
    return Float16(static_cast<float>(a) / static_cast<float>(b));
  }
  // Division by a size_t. Do it in full float precision to avoid
  // accuracy issues in converting the denominator to float16.
  inline Float16 operator/(const Float16 &a, size_t b) { return Float16(static_cast<float>(a) / static_cast<float>(b)); }

  // Negation flips the sign bit of the raw pattern; no rounding is involved.
  inline Float16 operator-(const Float16 &a) {
    constexpr uint16_t sign_mask = 0x8000;
    return Float16::FromRaw(static_cast<uint16_t>(a.int_value() ^ sign_mask));
  }

  // Comparisons are performed on the float values, so +0.0 equals -0.0 and NaN
  // compares unequal to everything, including itself.
  inline bool operator==(const Float16 &a, const Float16 &b) {
    return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
  }
  inline bool operator!=(const Float16 &a, const Float16 &b) {
    return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
  }
  inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
  inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
  inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
  inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }

  // Streams the value using the float formatting of the stream.
  inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }
  }  // namespace mindspore

  // Global alias so the non-ARM build exposes the same `float16` name as the ARM build.
  using float16 = mindspore::Float16;
  190. namespace std {
  191. template <>
  192. struct hash<float16> {
  193. std::size_t operator()(const float16 &f16) const noexcept { return static_cast<std::size_t>(f16.int_value()); }
  194. };
  195. template <>
  196. struct numeric_limits<float16> {
  197. static constexpr bool is_specialized = true;
  198. static constexpr bool is_signed = true;
  199. static constexpr bool is_integer = false;
  200. static constexpr bool is_exact = false;
  201. static constexpr bool has_infinity = true;
  202. static constexpr bool has_quiet_NaN = true;
  203. static constexpr bool has_signaling_NaN = true;
  204. static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  205. static constexpr bool has_denorm_loss = false;
  206. static constexpr std::float_round_style round_style = std::round_to_nearest;
  207. static constexpr bool is_iec559 = false;
  208. static constexpr bool is_bounded = false;
  209. static constexpr bool is_modulo = false;
  210. static constexpr int digits = 11;
  211. static constexpr int digits10 = 3;
  212. static constexpr int max_digits10 = 5;
  213. static constexpr int radix = 2;
  214. static constexpr int min_exponent = -13;
  215. static constexpr int min_exponent10 = -4;
  216. static constexpr int max_exponent = 16;
  217. static constexpr int max_exponent10 = 4;
  218. static constexpr bool traps = true;
  219. static constexpr bool tinyness_before = false;
  220. static constexpr uint16_t raw_min = 0x400;
  221. static constexpr uint16_t raw_max = 0x7bff;
  222. static constexpr uint16_t raw_lowest = 0xfbff;
  223. static constexpr uint16_t raw_epsilon = 0x0800;
  224. static constexpr float round_error_value = 0.5;
  225. static float16(min)() noexcept { return float16::FromRaw(raw_min); }
  226. static float16(max)() noexcept { return float16::FromRaw(raw_max); }
  227. static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
  228. static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
  229. static float16 round_error() noexcept { return float16(round_error_value); }
  230. static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
  231. static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  232. static float16 signaling_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  233. static float16 denorm_min() noexcept { return float16::FromRaw(1); }
  234. };
  235. // If std::numeric_limits<T> is specialized, should also specialize
  236. // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
  237. // std::numeric_limits<const volatile T>
  238. // https://stackoverflow.com/a/16519653/
  239. template <>
  240. struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  241. template <>
  242. struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  243. template <>
  244. struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
  245. } // namespace std
  246. // Implements standard math functions for float16.
  247. inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
  248. inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
  249. inline bool(isfinite)(const float16 &a) { return !(isinf(a)) && !(isnan(a)); }
  250. inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
  251. inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
  252. inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
  253. inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
  254. inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
  255. inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
  256. inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
  257. inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
  258. inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
  259. inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
  260. inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
  261. inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
  262. inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
  263. inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
  264. inline float16 pow(const float16 &a, const float16 &b) {
  265. return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
  266. }
  267. #endif // ENABLE_ARM32 || ENABLE_ARM64
  268. inline float half_to_float(float16 h) { return static_cast<float>(h); }
  269. #endif // MINDSPORE_CORE_BASE_FLOAT16_H_