You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common_func.c 3.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/common_func.h"
  17. #include "nnacl/quantization/fixed_point.h"
  // Combines four indices into one linear offset, using shape[1], shape[2] and
  // shape[3] as the extents of the three trailing dimensions. shape[0] is not read.
  // The result equals ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3.
  int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
    int linear = dim0;
    linear = linear * shape[1] + dim1;
    linear = linear * shape[2] + dim2;
    linear = linear * shape[3] + dim3;
    return linear;
  }
  // Combines three leading indices into a linear offset and scales it by
  // shape[3]; the result is the same value offset() yields with dim3 == 0.
  // shape[0] is not read.
  int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
    int linear = dim0;
    linear = linear * shape[1] + dim1;
    linear = linear * shape[2] + dim2;
    return linear * shape[3];
  }
  // Array-argument form of offset(): dims[0..3] supply the four indices.
  int offset4d(const int *shape, const int *dims) {
    const int d0 = dims[0];
    const int d1 = dims[1];
    const int d2 = dims[2];
    const int d3 = dims[3];
    return offset(shape, d0, d1, d2, d3);
  }
  // Returns the smaller of a and b (b when they are equal, which is the same value).
  int8_t MinInt8(int8_t a, int8_t b) {
    if (a < b) {
      return a;
    }
    return b;
  }
  // Returns the larger of a and b (a when they are equal, which is the same value).
  int8_t MaxInt8(int8_t a, int8_t b) {
    if (a < b) {
      return b;
    }
    return a;
  }
  // Applies ReLU (negative values become 0) to ele_num floats read from data and
  // stores the results in dst. data and dst may point to the same buffer for
  // in-place use. Does nothing when ele_num <= 0.
  void ReluFp32(float *data, float *dst, int ele_num) {
    // A non-positive count must not touch memory at all.
    if (ele_num <= 0) {
      return;
    }
    int pos = 0;
  #ifdef ENABLE_NEON
    // Vector path: handle every complete group of C4NUM lanes.
    const float32x4_t zero_data = vdupq_n_f32(0);
    for (; pos <= ele_num - C4NUM; pos += C4NUM) {
      float32x4_t relu_data = vld1q_f32(data + pos);
      vst1q_f32(dst + pos, vmaxq_f32(relu_data, zero_data));
    }
  #endif
    // Scalar path: the whole range without NEON, otherwise only the leftover tail.
    // Results always go to dst so both build configurations behave the same.
    for (; pos < ele_num; ++pos) {
      const float value = data[pos];
      dst[pos] = value < 0 ? 0 : value;
    }
  }
  // Applies ReLU6 (values clamped to the range [0, 6]) to ele_num floats read
  // from data and stores the results in dst. data and dst may point to the same
  // buffer for in-place use. Does nothing when ele_num <= 0.
  void Relu6Fp32(float *data, float *dst, int ele_num) {
    // A non-positive count must not touch memory at all.
    if (ele_num <= 0) {
      return;
    }
    int pos = 0;
  #ifdef ENABLE_NEON
    // Vector path: handle every complete group of C4NUM lanes.
    const float32x4_t zero_data = vdupq_n_f32(0);
    const float32x4_t six_data = vdupq_n_f32(6);
    for (; pos <= ele_num - C4NUM; pos += C4NUM) {
      float32x4_t relu6_data = vld1q_f32(data + pos);
      relu6_data = vmaxq_f32(relu6_data, zero_data);
      relu6_data = vminq_f32(relu6_data, six_data);
      vst1q_f32(dst + pos, relu6_data);
    }
  #endif
    // Scalar path: the whole range without NEON, otherwise only the leftover tail.
    // Results always go to dst so both build configurations behave the same.
    for (; pos < ele_num; ++pos) {
      float value = data[pos];
      value = value < 0 ? 0 : value;
      value = value > 6 ? 6 : value;
      dst[pos] = value;
    }
  }