You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

groupnorm.cpp 7.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "groupnorm.h"
  15. namespace ncnn {
  16. GroupNorm::GroupNorm()
  17. {
  18. one_blob_only = true;
  19. support_inplace = true;
  20. }
  21. int GroupNorm::load_param(const ParamDict& pd)
  22. {
  23. group = pd.get(0, 1);
  24. channels = pd.get(1, 0);
  25. eps = pd.get(2, 0.001f);
  26. affine = pd.get(3, 1);
  27. return 0;
  28. }
  29. int GroupNorm::load_model(const ModelBin& mb)
  30. {
  31. if (affine == 0)
  32. return 0;
  33. gamma_data = mb.load(channels, 1);
  34. if (gamma_data.empty())
  35. return -100;
  36. beta_data = mb.load(channels, 1);
  37. if (beta_data.empty())
  38. return -100;
  39. return 0;
  40. }
  41. int GroupNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
  42. {
  43. const int dims = bottom_top_blob.dims;
  44. const int channels_per_group = channels / group;
  45. if (dims == 1)
  46. {
  47. #pragma omp parallel for num_threads(opt.num_threads)
  48. for (int g = 0; g < group; g++)
  49. {
  50. Mat bottom_top_blob_g = bottom_top_blob.range(g * channels_per_group, channels_per_group);
  51. const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
  52. const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);
  53. // mean and var
  54. float sum = 0.f;
  55. for (int q = 0; q < channels_per_group; q++)
  56. {
  57. sum += bottom_top_blob_g[q];
  58. }
  59. float mean = sum / channels_per_group;
  60. float sqsum = 0.f;
  61. for (int q = 0; q < channels_per_group; q++)
  62. {
  63. float tmp = bottom_top_blob_g[q] - mean;
  64. sqsum += tmp * tmp;
  65. }
  66. float var = sqsum / channels_per_group;
  67. for (int q = 0; q < channels_per_group; q++)
  68. {
  69. float a;
  70. float b;
  71. if (affine)
  72. {
  73. float gamma = gamma_data_g[q];
  74. float beta = beta_data_g[q];
  75. a = gamma / sqrtf(var + eps);
  76. b = -mean * a + beta;
  77. }
  78. else
  79. {
  80. a = 1.f / (sqrtf(var + eps));
  81. b = -mean * a;
  82. }
  83. bottom_top_blob_g[q] = bottom_top_blob_g[q] * a + b;
  84. }
  85. }
  86. }
  87. if (dims == 2)
  88. {
  89. int w = bottom_top_blob.w;
  90. #pragma omp parallel for num_threads(opt.num_threads)
  91. for (int g = 0; g < group; g++)
  92. {
  93. Mat bottom_top_blob_g = bottom_top_blob.row_range(g * channels_per_group, channels_per_group);
  94. const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
  95. const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);
  96. // mean and var
  97. float sum = 0.f;
  98. for (int q = 0; q < channels_per_group; q++)
  99. {
  100. const float* ptr = bottom_top_blob_g.row(q);
  101. for (int i = 0; i < w; i++)
  102. {
  103. sum += ptr[i];
  104. }
  105. }
  106. float mean = sum / (channels_per_group * w);
  107. float sqsum = 0.f;
  108. for (int q = 0; q < channels_per_group; q++)
  109. {
  110. const float* ptr = bottom_top_blob_g.row(q);
  111. for (int i = 0; i < w; i++)
  112. {
  113. float tmp = ptr[i] - mean;
  114. sqsum += tmp * tmp;
  115. }
  116. }
  117. float var = sqsum / (channels_per_group * w);
  118. for (int q = 0; q < channels_per_group; q++)
  119. {
  120. float a;
  121. float b;
  122. if (affine)
  123. {
  124. float gamma = gamma_data_g[q];
  125. float beta = beta_data_g[q];
  126. a = gamma / sqrtf(var + eps);
  127. b = -mean * a + beta;
  128. }
  129. else
  130. {
  131. a = 1.f / (sqrtf(var + eps));
  132. b = -mean * a;
  133. }
  134. float* ptr = bottom_top_blob_g.row(q);
  135. for (int i = 0; i < w; i++)
  136. {
  137. ptr[i] = ptr[i] * a + b;
  138. }
  139. }
  140. }
  141. }
  142. if (dims == 3 || dims == 4)
  143. {
  144. int w = bottom_top_blob.w;
  145. int h = bottom_top_blob.h;
  146. int d = bottom_top_blob.d;
  147. int size = w * h * d;
  148. #pragma omp parallel for num_threads(opt.num_threads)
  149. for (int g = 0; g < group; g++)
  150. {
  151. Mat bottom_top_blob_g = bottom_top_blob.channel_range(g * channels_per_group, channels_per_group);
  152. const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
  153. const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);
  154. // mean and var
  155. float sum = 0.f;
  156. for (int q = 0; q < channels_per_group; q++)
  157. {
  158. const float* ptr = bottom_top_blob_g.channel(q);
  159. for (int i = 0; i < size; i++)
  160. {
  161. sum += ptr[i];
  162. }
  163. }
  164. float mean = sum / (channels_per_group * size);
  165. float sqsum = 0.f;
  166. for (int q = 0; q < channels_per_group; q++)
  167. {
  168. const float* ptr = bottom_top_blob_g.channel(q);
  169. for (int i = 0; i < size; i++)
  170. {
  171. float tmp = ptr[i] - mean;
  172. sqsum += tmp * tmp;
  173. }
  174. }
  175. float var = sqsum / (channels_per_group * size);
  176. for (int q = 0; q < channels_per_group; q++)
  177. {
  178. float a;
  179. float b;
  180. if (affine)
  181. {
  182. float gamma = gamma_data_g[q];
  183. float beta = beta_data_g[q];
  184. a = gamma / sqrtf(var + eps);
  185. b = -mean * a + beta;
  186. }
  187. else
  188. {
  189. a = 1.f / (sqrtf(var + eps));
  190. b = -mean * a;
  191. }
  192. float* ptr = bottom_top_blob_g.channel(q);
  193. for (int i = 0; i < size; i++)
  194. {
  195. ptr[i] = ptr[i] * a + b;
  196. }
  197. }
  198. }
  199. }
  200. return 0;
  201. }
  202. } // namespace ncnn