You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common_func.c 5.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/common_func.h"
  /**
   * Flattens a 4-D index (dim0, dim1, dim2, dim3) into a single linear offset.
   *
   * Only shape[1], shape[2] and shape[3] are read; shape[0] is never accessed.
   *
   * @param shape extents array; must provide at least indices 1..3
   * @param dim0  index along the first axis
   * @param dim1  index along the second axis
   * @param dim2  index along the third axis
   * @param dim3  index along the fourth axis
   * @return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3
   */
  int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
    // Fold one axis at a time, outermost first.
    int flat = dim0;
    flat = flat * shape[1] + dim1;
    flat = flat * shape[2] + dim2;
    flat = flat * shape[3] + dim3;
    return flat;
  }
  /**
   * Computes the linear offset of the first element addressed by
   * (dim0, dim1, dim2) in a 4-D layout, i.e. the same formula as a full
   * 4-D offset with the last index fixed at 0.
   *
   * Only shape[1], shape[2] and shape[3] are read.
   *
   * @param shape extents array; must provide at least indices 1..3
   * @param dim0  index along the first axis
   * @param dim1  index along the second axis
   * @param dim2  index along the third axis
   * @return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3]
   */
  int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
    int flat = dim0 * shape[1] + dim1;
    flat = flat * shape[2] + dim2;
    return flat * shape[3];
  }
  /**
   * Array-argument convenience wrapper around offset(): takes the four
   * indices from dims[0..3] and forwards them together with shape.
   *
   * @param shape extents array passed through to offset()
   * @param dims  the four indices; dims[0]..dims[3] are read
   * @return the value returned by offset(shape, dims[0], dims[1], dims[2], dims[3])
   */
  int offset4d(const int *shape, const int *dims) {
    const int d0 = dims[0];
    const int d1 = dims[1];
    const int d2 = dims[2];
    const int d3 = dims[3];
    return offset(shape, d0, d1, d2, d3);
  }
  /**
   * Returns the smaller of two signed 8-bit values.
   *
   * @param a first operand
   * @param b second operand
   * @return a when a < b, otherwise b
   */
  int8_t MinInt8(int8_t a, int8_t b) {
    if (a < b) {
      return a;
    }
    return b;
  }
  /**
   * Returns the larger of two signed 8-bit values.
   *
   * @param a first operand
   * @param b second operand
   * @return b when a < b, otherwise a
   */
  int8_t MaxInt8(int8_t a, int8_t b) {
    if (a < b) {
      return b;
    }
    return a;
  }
  /**
   * Element-wise ReLU: for every i in [0, ele_num), stores
   * (data[i] < 0 ? 0 : data[i]) into dst[i].
   *
   * The result is always written to dst, on both the NEON and the scalar
   * path. Passing the same pointer for data and dst performs the operation
   * in place.
   *
   * @param data    input buffer with at least ele_num readable floats
   * @param dst     output buffer with room for at least ele_num floats;
   *                may be identical to data
   * @param ele_num number of elements to process; a value <= 0 is a no-op
   */
  void ReluFp32(float *data, float *dst, int ele_num) {
    // Both loops are bounded by ele_num itself, so a zero or negative count
    // never produces a negative start index and no memory is touched.
    int i = 0;
  #ifdef ENABLE_NEON
    // Vector path: consume every complete group of C4NUM floats
    // (C4NUM is expected to match the 4 lanes of float32x4_t).
    const float32x4_t zero_data = vdupq_n_f32(0);
    for (; ele_num - i >= C4NUM; i += C4NUM) {
      const float32x4_t relu_data = vmaxq_f32(vld1q_f32(data + i), zero_data);
      vst1q_f32(dst + i, relu_data);
    }
  #endif
    // Scalar path: the whole range without NEON, or the leftover tail with it.
    for (; i < ele_num; ++i) {
      dst[i] = data[i] < 0 ? 0 : data[i];
    }
  }
  /**
   * Element-wise ReLU6: for every i in [0, ele_num), clamps data[i] to the
   * range [0, 6] (negative values become 0, values above 6 become 6) and
   * stores the result into dst[i].
   *
   * The result is always written to dst, on both the NEON and the scalar
   * path. Passing the same pointer for data and dst performs the operation
   * in place.
   *
   * @param data    input buffer with at least ele_num readable floats
   * @param dst     output buffer with room for at least ele_num floats;
   *                may be identical to data
   * @param ele_num number of elements to process; a value <= 0 is a no-op
   */
  void Relu6Fp32(float *data, float *dst, int ele_num) {
    // Both loops are bounded by ele_num itself, so a zero or negative count
    // never produces a negative start index and no memory is touched.
    int i = 0;
  #ifdef ENABLE_NEON
    // Vector path: consume every complete group of C4NUM floats
    // (C4NUM is expected to match the 4 lanes of float32x4_t).
    const float32x4_t zero_data = vdupq_n_f32(0);
    const float32x4_t six_data = vdupq_n_f32(6);
    for (; ele_num - i >= C4NUM; i += C4NUM) {
      float32x4_t relu6_data = vld1q_f32(data + i);
      relu6_data = vmaxq_f32(relu6_data, zero_data);
      relu6_data = vminq_f32(relu6_data, six_data);
      vst1q_f32(dst + i, relu6_data);
    }
  #endif
    // Scalar path: the whole range without NEON, or the leftover tail with it.
    for (; i < ele_num; ++i) {
      const float lower_clamped = data[i] < 0 ? 0 : data[i];
      dst[i] = lower_clamped > 6 ? 6 : lower_clamped;
    }
  }
  73. #ifdef ENABLE_AVX
  74. #ifdef WIN32
  /**
   * Scalar element-wise ReLU: for every i in [0, ele_num), stores
   * (data[i] < 0 ? 0 : data[i]) into dst[i].
   *
   * Passing the same pointer for data and dst performs the operation in
   * place.
   *
   * @param data    input buffer with at least ele_num readable floats
   * @param dst     output buffer with room for at least ele_num floats;
   *                may be identical to data
   * @param ele_num number of elements to process; a value <= 0 is a no-op
   */
  void ReluFp32C8(float *data, float *dst, int ele_num) {
    // A single loop bounded by ele_num: no block arithmetic, so a zero or
    // negative count cannot yield a negative start index.
    for (int i = 0; i < ele_num; ++i) {
      dst[i] = data[i] < 0 ? 0 : data[i];
    }
  }
  /**
   * Scalar element-wise ReLU6: for every i in [0, ele_num), clamps data[i]
   * to the range [0, 6] (negative values become 0, values above 6 become 6)
   * and stores the result into dst[i].
   *
   * Passing the same pointer for data and dst performs the operation in
   * place.
   *
   * @param data    input buffer with at least ele_num readable floats
   * @param dst     output buffer with room for at least ele_num floats;
   *                may be identical to data
   * @param ele_num number of elements to process; a value <= 0 is a no-op
   */
  void Relu6Fp32C8(float *data, float *dst, int ele_num) {
    // A single loop bounded by ele_num: no block arithmetic, so a zero or
    // negative count cannot yield a negative start index.
    for (int i = 0; i < ele_num; ++i) {
      const float lower_clamped = data[i] < 0 ? 0 : data[i];
      dst[i] = lower_clamped > 6 ? 6 : lower_clamped;
    }
  }
  118. #endif
  119. #endif