You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

forward.cpp 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. /**
  2. * \file dnn/src/cuda/resize/forward.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/cv/common.h"
  12. #include "src/common/cv/enums.h"
  13. #include "src/cuda/handle.h"
  14. #include "src/cuda/resize/common.h"
  15. #include "src/cuda/resize/helper.h"
  16. #include "src/cuda/resize/opr_impl.h"
  17. #include "src/cuda/resize/resize_cv.cuh"
  18. #include "src/cuda/utils.h"
  19. using namespace megdnn;
  20. using namespace cuda;
  21. namespace {
  22. void resize_cv_proxy(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  23. InterpolationMode imode, void* workspace,
  24. cudaStream_t stream) {
  25. using namespace megcv;
  26. for (size_t i = 0; i < src.layout.shape[0]; ++i) {
  27. if (dst.layout.dtype == dtype::Float32()) {
  28. Mat<float> src_mat = TensorND2Mat<float>(src, i);
  29. Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
  30. resize::resize_cv<float>(
  31. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  32. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  33. src_mat.step(), dst_mat.step(), src_mat.channels(), imode,
  34. workspace, stream);
  35. } else if (dst.layout.dtype == dtype::Uint8()) {
  36. Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
  37. Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
  38. resize::resize_cv<uchar>(
  39. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  40. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  41. src_mat.step(), dst_mat.step(), src_mat.channels(), imode,
  42. workspace, stream);
  43. } else {
  44. megdnn_throw("Unsupported datatype of WarpAffine optr.");
  45. }
  46. }
  47. }
  48. } // anonymous namespace
  49. size_t ResizeImpl::get_workspace_in_bytes(const TensorLayout& src,
  50. const TensorLayout& dst) {
  51. InterpolationMode imode = param().imode;
  52. if (param().format == Param::Format::NCHW ||
  53. (imode != Param::InterpolationMode::CUBIC &&
  54. imode != Param::InterpolationMode::LANCZOS4)) {
  55. return 0;
  56. }
  57. size_t src_rows = src.shape[1];
  58. size_t dst_rows = dst.shape[1];
  59. size_t src_cols = src.shape[2];
  60. size_t dst_cols = dst.shape[2];
  61. size_t ch = src.shape[3];
  62. size_t dst_area_size = dst_rows * dst_cols;
  63. size_t src_area_size = src_rows * src_cols;
  64. bool enlarge = dst_area_size > src_area_size;
  65. bool shrink = dst_area_size <= src_area_size;
  66. bool U8 = src.dtype == dtype::Uint8();
  67. megdnn_assert(src.dtype == dtype::Uint8() || src.dtype == dtype::Float32());
  68. bool F32_1 = !U8 && ch == 1;
  69. bool F32_3 = !U8 && ch == 3;
  70. bool use_vector = (enlarge && (dst_area_size <= 500 * 500)) ||
  71. (shrink && (F32_3 || (U8 && dst_area_size <= 500 * 500) ||
  72. (F32_1 && dst_area_size <= 1000 * 1000)));
  73. if (!use_vector) {
  74. int coef_size = 0;
  75. if (imode == Param::InterpolationMode::CUBIC) {
  76. coef_size = 4;
  77. } else {
  78. coef_size = 8;
  79. megdnn_assert(imode == Param::InterpolationMode::LANCZOS4);
  80. }
  81. if (U8) {
  82. return dst_rows * coef_size * sizeof(short) + //! dev_coef_row
  83. dst_rows * sizeof(int) + //! dev_sr
  84. dst_cols * coef_size * sizeof(short) + //! dev_coef_col
  85. dst_cols * sizeof(int); //! dev_sc
  86. } else {
  87. return dst_rows * coef_size * sizeof(float) + //! dev_coef_row
  88. dst_rows * sizeof(int) + //! dev_sr
  89. dst_cols * coef_size * sizeof(float) + //! dev_coef_col
  90. dst_cols * sizeof(int); //! dev_sc
  91. }
  92. }
  93. return 0;
  94. }
//! Execute the resize: route to the right CUDA kernel based on layout
//! (NHWC / NCHW / NCHW4), interpolation mode and dtype, enqueueing work on
//! this handle's stream (asynchronous; no sync is performed here).
//! NOTE(review): dst is declared _megdnn_tensor_in although it is the
//! output tensor -- presumably both macros expand to the same reference
//! type; confirm against the declaration in opr_impl.h.
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
                      _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    auto stream = cuda_stream(this->handle());
    bool is_nhwc = param().format == param::Resize::Format::NHWC;
    //! Strides stay 0 for the NHWC path (contiguous layout assumed there);
    //! they are only filled in for NCHW below.
    size_t C, IH, IW, OH, OW;
    ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
    if (is_nhwc) {
        //! Non-LINEAR NHWC with contiguous W*C goes through the
        //! OpenCV-style kernels (workspace holds their coefficient tables).
        if (param().imode != Param::InterpolationMode::LINEAR &&
            is_nhwc_contig_wc(src.layout)) {
            resize_cv_proxy(src, dst, resize::get_imode(param().imode),
                            workspace.raw_ptr, stream);
            return;
        }
        //! NHWC: shape = (N, H, W, C)
        C = src.layout.shape[3];
        IH = src.layout.shape[1];
        IW = src.layout.shape[2];
        OH = dst.layout.shape[1];
        OW = dst.layout.shape[2];
    } else if (param().format == param::Resize::Format::NCHW) {
        //! NCHW: shape = (N, C, H, W); pass explicit strides so
        //! non-contiguous inputs are handled.
        C = src.layout.shape[1];
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
        S_IN = src.layout.stride[0];
        S_IC = src.layout.stride[1];
        S_IH = src.layout.stride[2];
        S_IW = src.layout.stride[3];
    } else {
        //! NCHW4: quantized int8, 4 channels packed per inner element, so
        //! the logical channel count is shape[1] * 4. Handled by a
        //! dedicated kernel; returns early.
        megdnn_assert(param().format == param::Resize::Format::NCHW4,
                      "invalid resize format");
        megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS8);
        C = src.layout.shape[1] * 4;
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
        resize::forward_proxy_nchw4(src.compatible_ptr<int8_t>(),
                                    dst.compatible_ptr<int8_t>(), src.layout[0],
                                    C, IH, IW, OH, OW, stream);
        return;
    }
    //! NOTE(review): this assert is also reached on the NHWC LINEAR /
    //! non-contiguous path, so the "for NCHW format" wording in the message
    //! is narrower than the actual condition.
    megdnn_assert(param().imode == Param::InterpolationMode::LINEAR ||
                  param().imode == Param::InterpolationMode::NEAREST,
                  "unsupported interpolation mode for NCHW format");
    //! Generic strided kernel, instantiated per dtype.
    if (src.layout.dtype == dtype::Float32{}) {
        resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
                              src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
                              src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                              S_IH, S_IW, stream);
    } else if (src.layout.dtype == dtype::Uint8()) {
        resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
                              src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
                              src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                              S_IH, S_IW, stream);
    } else if (src.layout.dtype == dtype::Int8()) {
        resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
                              src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
                              src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                              S_IH, S_IW, stream);
    } else {
        megdnn_throw(
                ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
    }
}
  161. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台