You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batchnorm_arm.cpp 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "batchnorm_arm.h"
  15. #if __ARM_NEON
  16. #include <arm_neon.h>
  17. #endif // __ARM_NEON
  18. namespace ncnn {
  19. DEFINE_LAYER_CREATOR(BatchNorm_arm)
  20. BatchNorm_arm::BatchNorm_arm()
  21. {
  22. #if __ARM_NEON
  23. support_packing = true;
  24. #endif // __ARM_NEON
  25. }
  26. int BatchNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
  27. {
  28. int dims = bottom_top_blob.dims;
  29. int elempack = bottom_top_blob.elempack;
  30. #if __ARM_NEON
  31. if (opt.use_packing_layout)
  32. {
  33. if (elempack == 4)
  34. {
  35. if (dims == 1)
  36. {
  37. int w = bottom_top_blob.w;
  38. #pragma omp parallel for num_threads(opt.num_threads)
  39. for (int i=0; i<w; i++)
  40. {
  41. float* ptr = (float*)bottom_top_blob + i * 4;
  42. float32x4_t _a = vld1q_f32((const float*)a_data + i * 4);
  43. float32x4_t _b = vld1q_f32((const float*)b_data + i * 4);
  44. float32x4_t _p = vld1q_f32(ptr);
  45. _p = vmlaq_f32(_a, _p, _b);
  46. vst1q_f32(ptr, _p);
  47. }
  48. }
  49. if (dims == 2)
  50. {
  51. int w = bottom_top_blob.w;
  52. int h = bottom_top_blob.h;
  53. #pragma omp parallel for num_threads(opt.num_threads)
  54. for (int i=0; i<h; i++)
  55. {
  56. float32x4_t _a = vld1q_f32((const float*)a_data + i * 4);
  57. float32x4_t _b = vld1q_f32((const float*)b_data + i * 4);
  58. float* ptr = bottom_top_blob.row(i);
  59. for (int j=0; j<w; j++)
  60. {
  61. float32x4_t _p = vld1q_f32(ptr);
  62. _p = vmlaq_f32(_a, _p, _b);
  63. vst1q_f32(ptr, _p);
  64. ptr += 4;
  65. }
  66. }
  67. }
  68. if (dims == 3)
  69. {
  70. int w = bottom_top_blob.w;
  71. int h = bottom_top_blob.h;
  72. int c = bottom_top_blob.c;
  73. int size = w * h;
  74. #pragma omp parallel for num_threads(opt.num_threads)
  75. for (int q=0; q<c; q++)
  76. {
  77. float32x4_t _a = vld1q_f32((const float*)a_data + q * 4);
  78. float32x4_t _b = vld1q_f32((const float*)b_data + q * 4);
  79. float* ptr = bottom_top_blob.channel(q);
  80. for (int i=0; i<size; i++)
  81. {
  82. float32x4_t _p = vld1q_f32(ptr);
  83. _p = vmlaq_f32(_a, _p, _b);
  84. vst1q_f32(ptr, _p);
  85. ptr += 4;
  86. }
  87. }
  88. }
  89. return 0;
  90. }
  91. } // opt.use_packing_layout
  92. #endif // __ARM_NEON
  93. if (dims != 3)
  94. return BatchNorm::forward_inplace(bottom_top_blob, opt);
  95. int w = bottom_top_blob.w;
  96. int h = bottom_top_blob.h;
  97. // int c = bottom_top_blob.c;
  98. int size = w * h;
  99. #pragma omp parallel for num_threads(opt.num_threads)
  100. for (int q=0; q<channels; q++)
  101. {
  102. float* ptr = bottom_top_blob.channel(q);
  103. float a = a_data[q];
  104. float b = b_data[q];
  105. #if __ARM_NEON
  106. int nn = size >> 2;
  107. int remain = size - (nn << 2);
  108. #else
  109. int remain = size;
  110. #endif // __ARM_NEON
  111. #if __ARM_NEON
  112. #if __aarch64__
  113. if (nn > 0)
  114. {
  115. asm volatile(
  116. "dup v1.4s, %w4 \n"
  117. "dup v2.4s, %w5 \n"
  118. "0: \n"
  119. "prfm pldl1keep, [%1, #128] \n"
  120. "ld1 {v0.4s}, [%1] \n"
  121. "orr v3.16b, v1.16b, v1.16b \n"
  122. "fmla v3.4s, v0.4s, v2.4s \n"
  123. "subs %w0, %w0, #1 \n"
  124. "st1 {v3.4s}, [%1], #16 \n"
  125. "bne 0b \n"
  126. : "=r"(nn), // %0
  127. "=r"(ptr) // %1
  128. : "0"(nn),
  129. "1"(ptr),
  130. "r"(a), // %4
  131. "r"(b) // %5
  132. : "cc", "memory", "v0", "v1", "v2", "v3"
  133. );
  134. }
  135. #else
  136. if (nn > 0)
  137. {
  138. asm volatile(
  139. "vdup.f32 q1, %4 \n"
  140. "vdup.f32 q2, %5 \n"
  141. "0: \n"
  142. "pld [%1, #128] \n"
  143. "vld1.f32 {d0-d1}, [%1 :128] \n"
  144. "vorr.32 q3, q1, q1 \n"
  145. "vmla.f32 q3, q0, q2 \n"
  146. "subs %0, #1 \n"
  147. "vst1.f32 {d6-d7}, [%1 :128]! \n"
  148. "bne 0b \n"
  149. : "=r"(nn), // %0
  150. "=r"(ptr) // %1
  151. : "0"(nn),
  152. "1"(ptr),
  153. "r"(a), // %4
  154. "r"(b) // %5
  155. : "cc", "memory", "q0", "q1", "q2", "q3"
  156. );
  157. }
  158. #endif // __aarch64__
  159. #endif // __ARM_NEON
  160. for (; remain>0; remain--)
  161. {
  162. *ptr = b * *ptr + a;
  163. ptr++;
  164. }
  165. }
  166. return 0;
  167. }
  168. } // namespace ncnn