You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common_func_int8.c 2.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/int8/common_func_int8.h"
  17. #include "nnacl/quantization/fixed_point.h"
  /**
   * Converts a channel-blocked int32 accumulator buffer into a row-major int8 buffer.
   *
   * For every (row r, channel c) pair the source element is read from
   *   in[(c / size) * in_plane_stride + r * size + (c % size)]
   * i.e. channels are grouped in blocks of `size`, and consecutive blocks are
   * `in_plane_stride` elements apart. The result is written to
   *   out[r * out_oc_stride + c].
   *
   * Per element: add bias[c] (when bias is not NULL), pass the sum through
   * MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift),
   * add zp, then clamp with MSMIN(maxi, .) followed by MSMAX(mini, .) before
   * narrowing to int8_t.
   *
   * @param in              int32 source buffer in the blocked layout described above.
   * @param out             int8 destination buffer.
   * @param bias            optional per-channel bias (indexed by c); may be NULL.
   * @param oc              number of output channels.
   * @param plane           number of rows.
   * @param out_oc_stride   distance, in elements, between consecutive rows of `out`.
   * @param in_plane_stride distance, in elements, between consecutive channel blocks of `in`.
   * @param multiplier      forwarded to MultiplyByQuantizedMultiplier.
   * @param mini            lower clamp bound.
   * @param maxi            upper clamp bound.
   * @param left_shift      forwarded to MultiplyByQuantizedMultiplier.
   * @param right_shift     forwarded to MultiplyByQuantizedMultiplier.
   * @param zp              value added after requantization (presumably the output zero point).
   * @param size            channel block width; the function does nothing when size <= 0.
   */
  void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane,
                            size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini,
                            int32_t maxi, int32_t left_shift, int32_t right_shift, int32_t zp, int size) {
    /* A zero block width would divide by zero; a negative one would turn into a
     * huge unsigned offset once multiplied by the size_t stride. Reject both. */
    if (size <= 0) {
      return;
    }
    const size_t block = (size_t)size;

    /* Counters and indices stay in size_t so they match the size_t bounds and
     * strides and are never truncated to int. */
    for (size_t r = 0; r < plane; r++) {
      for (size_t c = 0; c < oc; c++) {
        const size_t src_index = (c / block) * in_plane_stride + r * block + (c % block);
        const size_t dst_index = r * out_oc_stride + c;

        int32_t value = in[src_index];
        if (bias != NULL) {
          value += bias[c];
        }
        value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp;
        /* Upper bound first, then lower bound: same order as the original clamp. */
        value = MSMIN(maxi, value);
        value = MSMAX(mini, value);
        out[dst_index] = (int8_t)value;
      }
    }
  }
  41. void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier,
  42. int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi) {
  43. /* ((int32_t)row8x8-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */
  44. PostConvFuncCommInt8(in, out, bias, oc, plane, oc, UP_ROUND(plane, C8NUM) * C8NUM, multiplier, mini, maxi, left_shift,
  45. right_shift, zp, C8NUM);
  46. return;
  47. }
  48. void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
  49. int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
  50. int32_t maxi) {
  51. /* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */
  52. #ifndef ENABLE_ARM64
  53. PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi,
  54. left_shift, right_shift, zp, C4NUM);
  55. #else
  56. size_t oc4div = oc / C4NUM * C4NUM;
  57. size_t oc4res = oc % C4NUM;
  58. PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift,
  59. right_shift, zp, mini, maxi);
  60. #endif
  61. return;
  62. }