You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common_func_fp32.c 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/common_func_fp32.h"
  17. void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel,
  18. size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) {
  19. int oc_div = 0, oc_mod = 0;
  20. for (int oc = 0; oc < output_channel; oc++) {
  21. if (size != 0) {
  22. oc_div = oc / size;
  23. oc_mod = oc % size;
  24. } else {
  25. return;
  26. }
  27. for (int hw = 0; hw < plane_size; hw++) {
  28. int src_index = oc_div * size * plane_stride + hw * size + oc_mod;
  29. int dst_index = hw * oc_stride + oc;
  30. float value = src_ptr_[src_index];
  31. if (bias_ptr != NULL) {
  32. value = value + bias_ptr[oc];
  33. }
  34. value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value);
  35. value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value);
  36. out_ptr[dst_index] = value;
  37. }
  38. }
  39. return;
  40. }
  41. void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
  42. size_t plane_size, size_t stride, size_t relu_type) {
  43. #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
  44. PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
  45. #else
  46. size_t oc8mod = output_channel % C8NUM;
  47. size_t oc8div = output_channel - oc8mod;
  48. size_t stride_size = stride * sizeof(float);
  49. PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type);
  50. #endif
  51. return;
  52. }
  53. void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
  54. size_t plane_size, size_t plane_stride, size_t relu_type) {
  55. #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
  56. size_t oc4mod = output_channel % C4NUM;
  57. size_t oc4div = output_channel - oc4mod;
  58. size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float);
  59. PostFuncBiasReluC4(out_ptr, c4_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type);
  60. #else
  61. PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type,
  62. C4NUM);
  63. #endif
  64. return;
  65. }
  66. #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
  /**
   * Left transform over tiles of `4 * length` floats.
   *
   * S is read as k rows by w columns of tiles, M is written as h rows by w
   * columns of tiles, and B supplies scalar coefficients read at B[i * h + y].
   * Each output tile is
   *   M[y][x][:] = sum over i in [0, k) of B[i * h + y] * S[i][x][:]
   * Output tiles are zeroed first; coefficients equal to zero are skipped, so
   * the matching source tile is never read.
   *
   * @param S       source tiles, k * w tiles of 4 * length floats
   * @param B       coefficients, indexed as B[i * h + y]
   * @param M       destination tiles, h * w tiles of 4 * length floats
   * @param w       tile columns in S and M
   * @param h       tile rows in M
   * @param k       tile rows in S (length of the reduction)
   * @param length  tile length in groups of 4 floats
   */
  void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
    // Kept in size_t: narrowing this product to int would truncate for large tiles.
    const size_t unit_step = 4 * length;
    for (size_t y = 0; y < h; ++y) {
      float *dst_row = M + y * w * unit_step;
      for (size_t x = 0; x < w; ++x) {
        float *dst = dst_row + x * unit_step;
        const float *src_col = S + x * unit_step;
        memset(dst, 0, unit_step * sizeof(float));
        for (size_t i = 0; i < k; ++i) {
          const float b = B[i * h + y];
          if (b == 0.0f) {
            continue;
          }
          const float *src = src_col + i * w * unit_step;
          for (size_t j = 0; j < unit_step; ++j) {
            dst[j] += src[j] * b;
          }
        }
      }
    }
  }
  /**
   * Right transform over tiles of `4 * length` floats: M = S * B.
   *
   * S is read as h rows by k columns of tiles, M is written as h rows by w
   * columns of tiles, and B supplies scalar coefficients read at B[i * h + x].
   * Each output tile is
   *   M[y][x][:] = sum over i in [0, k) of S[y][i][:] * B[i * h + x]
   * Output tiles are zeroed first; coefficients equal to zero are skipped, so
   * the matching source tile is never read.
   *
   * NOTE(review): B is indexed with row stride h although x ranges over w; the
   * two agree only when w == h. Confirm against callers before relying on
   * non-square shapes.
   *
   * @param S       source tiles, h * k tiles of 4 * length floats
   * @param B       coefficients, indexed as B[i * h + x]
   * @param M       destination tiles, h * w tiles of 4 * length floats
   * @param w       tile columns in M
   * @param h       tile rows in S and M
   * @param k       tile columns in S (length of the reduction)
   * @param length  tile length in groups of 4 floats
   */
  void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
    // Kept in size_t: narrowing this product to int would truncate for large tiles.
    const size_t unit_step = 4 * length;
    for (size_t y = 0; y < h; ++y) {
      float *dst_row = M + y * w * unit_step;
      const float *src_row = S + y * k * unit_step;
      for (size_t x = 0; x < w; ++x) {
        float *dst = dst_row + x * unit_step;
        memset(dst, 0, unit_step * sizeof(float));
        for (size_t i = 0; i < k; ++i) {
          const float b = B[i * h + x];
          if (b == 0.0f) {
            continue;
          }
          const float *src = src_row + i * unit_step;
          for (size_t j = 0; j < unit_step; ++j) {
            dst[j] += src[j] * b;
          }
        }
      }
    }
  }
  110. #endif