You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common_func_fp16.c 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp16/common_func_fp16.h"
  17. void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr,
  18. size_t output_channel, size_t plane_size, size_t oc_stride, size_t hw_stride,
  19. ActType act_type, int size) {
  20. if (size == 0) {
  21. return;
  22. }
  23. for (int oc = 0; oc < output_channel; oc++) {
  24. int oc_div = oc / size, oc_mod = oc % size;
  25. for (int hw = 0; hw < plane_size; hw++) {
  26. int src_index = oc_div * size * hw_stride + hw * size + oc_mod;
  27. int dst_index = hw * oc_stride + oc;
  28. float16_t value = src_ptr_[src_index];
  29. if (bias_ptr != NULL) {
  30. value = value + bias_ptr[oc];
  31. }
  32. value = (act_type == ActType_Relu || act_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value);
  33. value = (act_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value);
  34. out_ptr[dst_index] = value;
  35. }
  36. }
  37. return;
  38. }
  39. void PostConvFuncFp16C8(const float16_t *c8_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane,
  40. size_t oc_stride, ActType act_type) {
  41. size_t oc8mod = oc % C8NUM;
  42. size_t oc8div = oc - oc8mod;
  43. size_t stride_size = oc_stride * sizeof(float16_t);
  44. PostFuncBiasReluC8Fp16(nhwc_out, c8_out, bias, oc8div, oc8mod, plane, stride_size, act_type);
  45. return;
  46. }
  47. void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane,
  48. size_t plane_stride, ActType act_type) {
  49. size_t oc4mod = oc % C4NUM;
  50. size_t oc4div = oc - oc4mod;
  51. size_t stride_size = (plane_stride - plane) * C4NUM * sizeof(float16_t);
  52. PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type);
  53. return;
  54. }