You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common_func.c 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/common_func.h"
  17. void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel,
  18. size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) {
  19. int oc_div = 0, oc_mod = 0;
  20. for (int oc = 0; oc < output_channel; oc++) {
  21. if (size != 0) {
  22. oc_div = oc / size;
  23. oc_mod = oc % size;
  24. } else {
  25. return;
  26. }
  27. for (int hw = 0; hw < plane_size; hw++) {
  28. int src_index = oc_div * size * plane_stride + hw * size + oc_mod;
  29. int dst_index = hw * oc_stride + oc;
  30. float value = src_ptr_[src_index];
  31. if (bias_ptr != NULL) {
  32. value = value + bias_ptr[oc];
  33. }
  34. value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value);
  35. value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value);
  36. out_ptr[dst_index] = value;
  37. }
  38. }
  39. return;
  40. }
  41. void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
  42. size_t plane_size, size_t stride, size_t relu_type) {
  43. #ifndef ENABLE_ARM
  44. PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
  45. #else
  46. size_t oc8mod = output_channel % C8NUM;
  47. size_t oc8div = output_channel - oc8mod;
  48. size_t stride_size = stride * sizeof(float);
  49. PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type);
  50. #endif
  51. return;
  52. }
  53. void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
  54. size_t plane_size, size_t plane_stride, size_t relu_type) {
  55. #ifdef ENABLE_ARM
  56. size_t oc4mod = output_channel % C4NUM;
  57. size_t oc4div = output_channel - oc4mod;
  58. size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float);
  59. PostFuncBiasReluC4(out_ptr, c4_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type);
  60. #else
  61. PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type,
  62. C4NUM);
  63. #endif
  64. return;
  65. }
  66. #ifndef ENABLE_ARM
  /**
   * Left-side Winograd transform over tiles of 4 * length contiguous floats.
   *
   * For every output tile (y, x), with 0 <= y < h and 0 <= x < w:
   *     M[y * w + x] = sum over i in [0, k) of B[i * h + y] * S[i * w + x]
   * where M[...] and S[...] index whole tiles. Output tiles are zeroed before
   * accumulation, and terms whose coefficient equals zero are skipped.
   *
   * @param S       source tiles
   * @param B       scalar coefficients, read at index i * h + y
   * @param M       destination tiles (overwritten)
   * @param w       number of tile columns
   * @param h       number of tile rows in M
   * @param k       number of terms summed per output tile
   * @param length  tile length in units of 4 floats
   */
  void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
    int tile_len = 4 * length;
    for (int row = 0; row < h; ++row) {
      for (int col = 0; col < w; ++col) {
        float *out_tile = M + row * w * tile_len + col * tile_len;
        const float *src_col = S + col * tile_len;
        memset(out_tile, 0, tile_len * sizeof(float));
        for (int term = 0; term < k; ++term) {
          const float coeff = B[term * h + row];
          if (coeff != 0.0f) {
            const float *in_tile = src_col + term * w * tile_len;
            int e = 0;
            while (e < tile_len) {
              out_tile[e] += in_tile[e] * coeff;
              ++e;
            }
          }
        }
      }
    }
  }
  // M = S * B , M = w*h * l, S = k*h * l, B = w*k
  /**
   * Right-side Winograd transform over tiles of 4 * length contiguous floats.
   *
   * For every output tile (y, x), with 0 <= y < h and 0 <= x < w:
   *     M[y * w + x] = sum over i in [0, k) of B[i * h + x] * S[y * k + i]
   * where M[...] and S[...] index whole tiles. Output tiles are zeroed before
   * accumulation, and terms whose coefficient equals zero are skipped.
   *
   * NOTE(review): B is read with a row stride of h, although the shape comment
   * above describes B as w*k; the two agree only when w == h. Confirm against
   * the callers.
   *
   * @param S       source tiles
   * @param B       scalar coefficients, read at index i * h + x
   * @param M       destination tiles (overwritten)
   * @param w       number of tile columns in M
   * @param h       number of tile rows
   * @param k       number of terms summed per output tile
   * @param length  tile length in units of 4 floats
   */
  void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
    int tile_len = 4 * length;
    for (int row = 0; row < h; ++row) {
      const float *src_row = S + row * k * tile_len;
      for (int col = 0; col < w; ++col) {
        float *out_tile = M + row * w * tile_len + col * tile_len;
        memset(out_tile, 0, tile_len * sizeof(float));
        for (int term = 0; term < k; ++term) {
          const float coeff = B[term * h + col];
          if (coeff != 0.0f) {
            const float *in_tile = src_row + term * tile_len;
            int e = 0;
            while (e < tile_len) {
              out_tile[e] += in_tile[e] * coeff;
              ++e;
            }
          }
        }
      }
    }
  }
  110. #endif
  /* A 32-bit pattern viewed either as an unsigned integer or as a float.
   * Assumes `unsigned int` and `float` are both 32 bits wide. */
  union float32_bits {
    unsigned int u;
    float f;
  };
  typedef union float32_bits float32_bits;

  /**
   * Converts an IEEE-754 half-precision bit pattern to a single-precision
   * float.
   *
   * Normal values, subnormals, signed zeros, infinities and NaNs are all
   * handled; every half value is exactly representable as a float, so no
   * rounding takes place.
   *
   * @param src_value  raw 16-bit half-precision encoding
   * @return the same value as a float
   */
  float ShortToFloat32(uint16_t src_value) {
    /* All bit arithmetic is done in unsigned: shifting the promoted (signed)
     * int `src_value & 0x8000` left by 16 would overflow into the sign bit,
     * which is undefined behaviour. */
    const unsigned int half_bits = src_value;
    const unsigned int shifted_exp = 0x7c00u << 13;   /* half exponent field, in float position */
    const float32_bits magic = {113u << 23};          /* 2^-14, the smallest normal half */

    float32_bits o;
    o.u = (half_bits & 0x7fffu) << 13;                /* exponent + mantissa moved into place */
    const unsigned int exp = shifted_exp & o.u;
    o.u += (127u - 15u) << 23;                        /* re-bias the exponent */

    if (exp == shifted_exp) {
      /* Inf / NaN: push the exponent field up to all ones. */
      o.u += (128u - 16u) << 23;
    } else if (exp == 0) {
      /* Zero / subnormal: bump the exponent by one, then subtract the implicit
       * leading one that the float format added; the FPU renormalises. */
      o.u += 1u << 23;
      o.f -= magic.f;
    }

    o.u |= (half_bits & 0x8000u) << 16;               /* restore the sign bit */
    return o.f;
  }
  /* Field layout of IEEE-754 binary32 / binary16, used by Float32ToShort. */
  static const unsigned int FP32_BIT_SIZE = 32;
  static const unsigned int FP32_EXPONENT_BIAS = 127;
  static const unsigned int FP32_SIGNIFICAND = 23;
  static const unsigned int FP32_EXPONENT_MAX = 255;
  static const unsigned int FP16_BIT_SIZE = 16;
  static const unsigned int FP16_EXPONENT_BIAS = 15;
  static const unsigned int FP16_SIGNIFICAND = 10;
  /* Largest / smallest half-biased exponent that still yields a finite,
   * non-zero half (30 = largest normal, -10 = smallest subnormal after rounding). */
  static const int FP16_EXPONENT_MAX = 30;
  static const int FP16_EXPONENT_MIN = -10;

  /**
   * Converts a single-precision float to an IEEE-754 half-precision bit
   * pattern.
   *
   * Rounding is to nearest, with ties rounded away from zero (the discarded
   * bits are rounded by adding half a unit and truncating).
   *
   *  - NaN stays NaN (the top payload bits are kept; a payload that would
   *    truncate to zero is replaced by a quiet NaN so it cannot turn into Inf).
   *  - +/-Inf and values too large for half become +/-Inf.
   *  - Values below the smallest half subnormal become zero with the input sign.
   *  - Values below the smallest normal half become half subnormals.
   *
   * Assumes `float` is IEEE-754 binary32 and `unsigned int` is 32 bits wide.
   *
   * @param src_value  value to convert
   * @return raw 16-bit half-precision encoding
   */
  uint16_t Float32ToShort(float src_value) {
    /* Reinterpret the bits of the float. A value cast such as
     * (unsigned int)src_value would convert the number itself (1.0f -> 1) and
     * is undefined for negative inputs. */
    union {
      float f;
      unsigned int u;
    } bits;
    bits.f = src_value;
    const unsigned int src_bits = bits.u;

    const unsigned int mantissa_shift = FP32_SIGNIFICAND - FP16_SIGNIFICAND; /* 13 discarded bits */
    const unsigned int round_half = 1u << (mantissa_shift - 1);              /* half of one half-ULP */
    const unsigned int half_exp_mask = 0x7C00u;                              /* also the encoding of +Inf */

    /* Sign bit, already moved to its position in the half encoding. */
    const unsigned int sign = (src_bits >> (FP32_BIT_SIZE - 1)) << (FP16_BIT_SIZE - 1);
    unsigned int mantissa = src_bits & 0x007FFFFFu;
    const unsigned int biased_exp = (src_bits >> FP32_SIGNIFICAND) & FP32_EXPONENT_MAX;

    if (biased_exp == FP32_EXPONENT_MAX) {
      if (mantissa == 0) {
        /* input float is infinity, return infinity half */
        return (uint16_t)(sign | half_exp_mask);
      }
      /* input float is NaN, return half NaN */
      unsigned int payload = mantissa >> mantissa_shift;
      if (payload == 0) {
        payload = 0x0200u; /* keep it a NaN when only low payload bits were set */
      }
      return (uint16_t)(sign | half_exp_mask | payload);
    }

    /* Exponent re-biased for half precision; may be <= 0 or > FP16_EXPONENT_MAX. */
    const int exp = (int)biased_exp - (int)FP32_EXPONENT_BIAS + (int)FP16_EXPONENT_BIAS;

    if (exp > FP16_EXPONENT_MAX) {
      /* exponent overflow - return infinity half */
      return (uint16_t)(sign | half_exp_mask);
    }

    if (exp <= 0) {
      if (exp < FP16_EXPONENT_MIN) {
        /* magnitude is below the smallest half subnormal (covers zero and
         * float subnormals too): signed zero */
        return (uint16_t)sign;
      }
      /* Normalized single below the smallest normal half: make the implicit
       * leading one explicit and shift it into subnormal position. */
      mantissa = (mantissa | 0x00800000u) >> (1 - exp);
      /* Round; a carry out of the mantissa lands in the exponent field and
       * correctly produces the smallest normal half. */
      return (uint16_t)(sign | ((mantissa + round_half) >> mantissa_shift));
    }

    /* Normal half. The rounded mantissa is ADDED to the exponent field, so a
     * rounding carry increments the exponent (OR-ing would drop the carry
     * whenever the exponent's low bit is already set). With exp ==
     * FP16_EXPONENT_MAX the carry yields 0x7C00, i.e. infinity. */
    const unsigned int magnitude =
      ((unsigned int)exp << FP16_SIGNIFICAND) + ((mantissa + round_half) >> mantissa_shift);
    return (uint16_t)(sign | magnitude);
  }