You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batchnorm_fp16.c 2.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp16/batchnorm_fp16.h"
  17. #include <math.h>
  18. void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
  19. BatchNormParameter *param, int task_id, float16_t *output) {
  20. int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
  21. int completed_units = task_id * units_per_thread;
  22. int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
  23. int cur_offset = completed_units * param->channel_;
  24. for (int i = 0; i < cur_unit; i++) {
  25. for (int c = 0; c < param->channel_; c++) {
  26. float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
  27. if (variance_sqrt != 0) {
  28. output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
  29. }
  30. }
  31. cur_offset += param->channel_;
  32. }
  33. }
  34. void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
  35. const void *variance, BatchNormParameter *param, int task_id, void *output) {
  36. int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
  37. int completed_units = task_id * units_per_thread;
  38. int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
  39. int cur_offset = completed_units * param->channel_;
  40. for (int i = 0; i < cur_unit; i++) {
  41. for (int c = 0; c < param->channel_; c++) {
  42. float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
  43. if (variance_sqrt != 0) {
  44. float16_t norm_val =
  45. (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
  46. ((float16_t *)output)[cur_offset + c] =
  47. norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
  48. }
  49. }
  50. cur_offset += param->channel_;
  51. }
  52. }