You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

upsample2_nchwxx.cpp 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * \file dnn/src/arm_common/resize/upsample2_nchwxx.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/arm_common/resize/upsample2_nchwxx.h"
  13. #include "src/arm_common/resize/helper.h"
  14. #include "src/arm_common/simd_macro/marm_neon.h"
  15. using namespace megdnn;
  16. using namespace arm_common;
  17. using namespace resize;
  18. namespace {
  /**
   * Blend one output vector of a 2x2 output block from four source taps.
   *
   * @tparam fh, fw  position (0 or 1 on each axis) of the produced output
   *                 inside the 2x2 output block.
   * @param src      taps ordered {top-left, top-right, bottom-left,
   *                 bottom-right}, as loaded by compute_linear_2x2_element.
   * @param alpha    2x2 weight table. alpha[fh][fw] weights the top-left tap;
   *                 XOR-ing an index with 1 picks the weight for the opposite
   *                 tap along that axis, so one table serves all four outputs.
   * @return the accumulated sum. This assumes simd_helper::fma(a, b, c)
   *         computes a + b * c; TODO confirm against SIMDHelper.
   */
  template <typename simd_helper, size_t fh, size_t fw>
  static inline typename simd_helper::simd_type compute_linear_element(
          const typename simd_helper::simd_type src[4],
          const typename simd_helper::simd_type alpha[2][2]) {
      // alpha row/column index applied to the top/bottom and left/right taps.
      constexpr size_t top = fh, bot = 1 ^ fh;
      constexpr size_t left = fw, right = 1 ^ fw;
      typename simd_helper::simd_type acc = simd_helper::dup(0);
      acc = simd_helper::fma(acc, src[0], alpha[top][left]);
      acc = simd_helper::fma(acc, src[1], alpha[top][right]);
      acc = simd_helper::fma(acc, src[2], alpha[bot][left]);
      acc = simd_helper::fma(acc, src[3], alpha[bot][right]);
      return acc;
  }
  30. template <typename simd_helper, bool has_right, bool has_bottom>
  31. static inline void compute_linear_2x2_element(
  32. const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
  33. size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
  34. constexpr size_t PC = simd_helper::simd_width;
  35. const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src};
  36. if (has_right) {
  37. src_ptr[1] += PC;
  38. src_ptr[3] += PC;
  39. }
  40. if (has_bottom) {
  41. src_ptr[2] += IW * PC;
  42. src_ptr[3] += IW * PC;
  43. }
  44. typename simd_helper::simd_type rsrc[4];
  45. rsrc[0] = simd_helper::load(src_ptr[0]);
  46. rsrc[1] = simd_helper::load(src_ptr[1]);
  47. rsrc[2] = simd_helper::load(src_ptr[2]);
  48. rsrc[3] = simd_helper::load(src_ptr[3]);
  49. typename simd_helper::simd_type rdst[4];
  50. rdst[0] = compute_linear_element<simd_helper, 0, 0>(rsrc, alpha);
  51. rdst[1] = compute_linear_element<simd_helper, 0, 1>(rsrc, alpha);
  52. rdst[2] = compute_linear_element<simd_helper, 1, 0>(rsrc, alpha);
  53. rdst[3] = compute_linear_element<simd_helper, 1, 1>(rsrc, alpha);
  54. simd_helper::store(dst, rdst[0]);
  55. if (has_right) {
  56. simd_helper::store(dst + PC, rdst[1]);
  57. }
  58. if (has_bottom) {
  59. simd_helper::store(dst + OW * PC, rdst[2]);
  60. }
  61. if (has_right && has_bottom) {
  62. simd_helper::store(dst + (OW + 1) * PC, rdst[3]);
  63. }
  64. }
  65. template <typename ctype>
  66. void linear_upsample2_nchwxx(
  67. const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
  68. using simd_helper = SIMDHelper<ctype>;
  69. size_t OW = IW * 2;
  70. constexpr size_t PC = simd_helper::simd_width;
  71. typename simd_helper::simd_type alpha[2][2];
  72. alpha[0][0] = simd_helper::dup(0.75 * 0.75);
  73. alpha[0][1] = simd_helper::dup(0.75 * 0.25);
  74. alpha[1][0] = simd_helper::dup(0.25 * 0.75);
  75. alpha[1][1] = simd_helper::dup(0.25 * 0.25);
  76. for (size_t i = 0; i < N; ++i) {
  77. compute_linear_2x2_element<simd_helper, false, false>(
  78. src_ptr, dst_ptr, IW, OW, alpha);
  79. {
  80. for (size_t iw = 0; iw + 1 < IW; ++iw) {
  81. compute_linear_2x2_element<simd_helper, true, false>(
  82. src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
  83. }
  84. }
  85. compute_linear_2x2_element<simd_helper, false, false>(
  86. src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
  87. dst_ptr += OW * PC;
  88. for (size_t ih = 0; ih + 1 < IH; ++ih) {
  89. compute_linear_2x2_element<simd_helper, false, true>(
  90. src_ptr, dst_ptr, IW, OW, alpha);
  91. for (size_t iw = 0; iw + 1 < IW; ++iw) {
  92. compute_linear_2x2_element<simd_helper, true, true>(
  93. src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
  94. }
  95. compute_linear_2x2_element<simd_helper, false, true>(
  96. src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
  97. src_ptr += IW * PC;
  98. dst_ptr += 2 * OW * PC;
  99. }
  100. compute_linear_2x2_element<simd_helper, false, false>(
  101. src_ptr, dst_ptr, IW, OW, alpha);
  102. {
  103. for (size_t iw = 0; iw + 1 < IW; ++iw) {
  104. compute_linear_2x2_element<simd_helper, true, false>(
  105. src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha);
  106. }
  107. }
  108. compute_linear_2x2_element<simd_helper, false, false>(
  109. src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
  110. src_ptr += IW * PC;
  111. dst_ptr += OW * PC;
  112. }
  113. }
  114. template <typename ctype>
  115. void nearest_upsample2_nchwxx(
  116. const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
  117. using simd_helper = SIMDHelper<ctype>;
  118. size_t OW = IW * 2;
  119. constexpr size_t PC = simd_helper::simd_width;
  120. for (size_t i = 0; i < N; ++i) {
  121. for (size_t ih = 0; ih < IH; ++ih) {
  122. for (size_t iw = 0; iw < IW; ++iw) {
  123. typename simd_helper::simd_type r0 =
  124. simd_helper::load(src_ptr + iw * PC);
  125. simd_helper::store(dst_ptr + (iw * 2) * PC, r0);
  126. simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0);
  127. simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0);
  128. simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0);
  129. }
  130. src_ptr += IW * PC;
  131. dst_ptr += 2 * OW * PC;
  132. }
  133. }
  134. }
  135. } // namespace
  136. void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32(
  137. const ResizeImpl::KernParam<float>& kern_param) {
  138. linear_upsample2_nchwxx(
  139. kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
  140. kern_param.ih, kern_param.iw);
  141. }
  142. void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32(
  143. const ResizeImpl::KernParam<float>& kern_param) {
  144. nearest_upsample2_nchwxx(
  145. kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
  146. kern_param.ih, kern_param.iw);
  147. }
  #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  /**
   * fp16 NCHW88 entry point for 2x linear upsampling.
   *
   * The dt_float16 buffers are reinterpreted as native __fp16 so that the
   * SIMDHelper<__fp16> specialisation is selected. The plane count is
   * n * c / 8, which assumes kern_param.c counts individual channels and is a
   * multiple of 8; TODO confirm.
   */
  void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16(
          const ResizeImpl::KernParam<dt_float16>& kern_param) {
      const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr());
      __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr());
      const size_t nr_planes = kern_param.n * kern_param.c / 8;
      linear_upsample2_nchwxx(src, dst, nr_planes, kern_param.ih, kern_param.iw);
  }

  /**
   * fp16 NCHW88 entry point for 2x nearest-neighbour upsampling.
   * Buffer handling and plane count are the same as in the linear variant.
   */
  void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16(
          const ResizeImpl::KernParam<dt_float16>& kern_param) {
      const __fp16* src = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr());
      __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr());
      const size_t nr_planes = kern_param.n * kern_param.c / 8;
      nearest_upsample2_nchwxx(src, dst, nr_planes, kern_param.ih, kern_param.iw);
  }
  #endif

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。