You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batchnorm.c 3.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/batchnorm.h"
  17. #include <math.h>
  18. #include "nnacl/batchnorm_parameter.h"
  19. #include "nnacl/op_base.h"
  20. #include "nnacl/errorcode.h"
  21. void BatchNormFp32(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
  22. void *output) {
  23. int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
  24. int completed_units = task_id * units_per_thread;
  25. int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
  26. int cur_offset = completed_units * param->channel_;
  27. for (int i = 0; i < cur_unit; i++) {
  28. for (int c = 0; c < param->channel_; c++) {
  29. float variance_sqrt = sqrt(((const float *)variance)[c] + param->epsilon_);
  30. ((float *)output)[cur_offset + c] =
  31. (((const float *)input)[cur_offset + c] - ((const float *)mean)[c]) / variance_sqrt;
  32. }
  33. cur_offset += param->channel_;
  34. }
  35. }
  36. void FusedBatchNormFp32(const void *input, const void *scale, const void *offset, const void *mean,
  37. const void *variance, BatchNormParameter *param, int task_id, void *output) {
  38. int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
  39. int completed_units = task_id * units_per_thread;
  40. int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
  41. int cur_offset = completed_units * param->channel_;
  42. for (int i = 0; i < cur_unit; i++) {
  43. for (int c = 0; c < param->channel_; c++) {
  44. float variance_sqrt = sqrt(((const float *)variance)[c] + param->epsilon_);
  45. float norm_val = (((const float *)input)[cur_offset + c] - ((const float *)mean)[c]) / variance_sqrt;
  46. ((float *)output)[cur_offset + c] = norm_val * ((const float *)scale)[c] + ((const float *)offset)[c];
  47. }
  48. cur_offset += param->channel_;
  49. }
  50. }
  51. void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param,
  52. float *save_mean, float *save_var) {
  53. float N = (float)param->unit_;
  54. for (int i = 0; i < param->unit_; i++) {
  55. for (int c = 0; c < param->channel_; c++) {
  56. int idx = i * param->channel_ + c;
  57. run_mean[c] += input[idx];
  58. run_var[c] += input[idx] * input[idx];
  59. }
  60. }
  61. const float VN = (N > 1.0f) ? (N - 1.0f) : 1.0f;
  62. for (int c = 0; c < param->channel_; c++) {
  63. run_mean[c] = run_mean[c] / N;
  64. run_var[c] = run_var[c] / VN - run_mean[c] * run_mean[c];
  65. save_mean[c] = param->momentum_ * save_mean[c] + (1 - param->momentum_) * run_mean[c];
  66. const float var = run_var[c];
  67. save_var[c] = param->momentum_ * save_var[c] + (1 - param->momentum_) * var;
  68. }
  69. }