You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

forward.cpp 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * \file dnn/src/cuda/resize/forward.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/cv/common.h"
  12. #include "src/cuda/handle.h"
  13. #include "src/cuda/resize/common.h"
  14. #include "src/cuda/resize/helper.h"
  15. #include "src/cuda/resize/opr_impl.h"
  16. #include "src/cuda/resize/resize_cv.cuh"
  17. #include "src/cuda/utils.h"
  18. using namespace megdnn;
  19. using namespace cuda;
  20. namespace {
  21. void resize_cv_proxy(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  22. InterpolationMode imode, void* workspace,
  23. cudaStream_t stream) {
  24. using namespace megcv;
  25. for (size_t i = 0; i < src.layout.shape[0]; ++i) {
  26. if (dst.layout.dtype == dtype::Float32()) {
  27. Mat<float> src_mat = TensorND2Mat<float>(src, i);
  28. Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
  29. resize::resize_cv<float>(
  30. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  31. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  32. src_mat.step(), dst_mat.step(), src_mat.channels(), imode,
  33. workspace, stream);
  34. } else if (dst.layout.dtype == dtype::Uint8()) {
  35. Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
  36. Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
  37. resize::resize_cv<uchar>(
  38. src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
  39. src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
  40. src_mat.step(), dst_mat.step(), src_mat.channels(), imode,
  41. workspace, stream);
  42. } else {
  43. megdnn_throw(
  44. megdnn_mangle("Unsupported datatype of WarpAffine optr."));
  45. }
  46. }
  47. }
  48. } // anonymous namespace
  49. size_t ResizeImpl::get_workspace_in_bytes(const TensorLayout& src,
  50. const TensorLayout& dst) {
  51. InterpolationMode imode = param().imode;
  52. if (param().format == Param::Format::NCHW ||
  53. (imode != Param::InterpolationMode::CUBIC &&
  54. imode != Param::InterpolationMode::LANCZOS4)) {
  55. return 0;
  56. }
  57. size_t src_rows = src.shape[1];
  58. size_t dst_rows = dst.shape[1];
  59. size_t src_cols = src.shape[2];
  60. size_t dst_cols = dst.shape[2];
  61. size_t ch = src.shape[3];
  62. size_t dst_area_size = dst_rows * dst_cols;
  63. size_t src_area_size = src_rows * src_cols;
  64. bool enlarge = dst_area_size > src_area_size;
  65. bool shrink = dst_area_size <= src_area_size;
  66. bool U8 = src.dtype == dtype::Uint8();
  67. megdnn_assert(src.dtype == dtype::Uint8() || src.dtype == dtype::Float32());
  68. bool F32_1 = !U8 && ch == 1;
  69. bool F32_3 = !U8 && ch == 3;
  70. bool use_vector = (enlarge && (dst_area_size <= 500 * 500)) ||
  71. (shrink && (F32_3 || (U8 && dst_area_size <= 500 * 500) ||
  72. (F32_1 && dst_area_size <= 1000 * 1000)));
  73. if (!use_vector) {
  74. int coef_size = 0;
  75. if (imode == Param::InterpolationMode::CUBIC) {
  76. coef_size = 4;
  77. } else {
  78. coef_size = 8;
  79. megdnn_assert(imode == Param::InterpolationMode::LANCZOS4);
  80. }
  81. if (U8) {
  82. return dst_rows * coef_size * sizeof(short) + //! dev_coef_row
  83. dst_rows * sizeof(int) + //! dev_sr
  84. dst_cols * coef_size * sizeof(short) + //! dev_coef_col
  85. dst_cols * sizeof(int); //! dev_sc
  86. } else {
  87. return dst_rows * coef_size * sizeof(float) + //! dev_coef_row
  88. dst_rows * sizeof(int) + //! dev_sr
  89. dst_cols * coef_size * sizeof(float) + //! dev_coef_col
  90. dst_cols * sizeof(int); //! dev_sc
  91. }
  92. }
  93. return 0;
  94. }
  95. void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
  96. _megdnn_workspace workspace) {
  97. check_exec(src.layout, dst.layout, workspace.size);
  98. auto stream = cuda_stream(this->handle());
  99. bool is_nhwc = param().format == param::Resize::Format::NHWC;
  100. size_t C, IH, IW, OH, OW;
  101. ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
  102. if (is_nhwc) {
  103. if (param().imode != Param::InterpolationMode::LINEAR &&
  104. is_nhwc_contig_wc(src.layout)) {
  105. resize_cv_proxy(src, dst, resize::get_imode(param().imode),
  106. workspace.raw_ptr, stream);
  107. return;
  108. }
  109. C = src.layout.shape[3];
  110. IH = src.layout.shape[1];
  111. IW = src.layout.shape[2];
  112. OH = dst.layout.shape[1];
  113. OW = dst.layout.shape[2];
  114. } else if (param().format == param::Resize::Format::NCHW) {
  115. C = src.layout.shape[1];
  116. IH = src.layout.shape[2];
  117. IW = src.layout.shape[3];
  118. OH = dst.layout.shape[2];
  119. OW = dst.layout.shape[3];
  120. S_IN = src.layout.stride[0];
  121. S_IC = src.layout.stride[1];
  122. S_IH = src.layout.stride[2];
  123. S_IW = src.layout.stride[3];
  124. } else {
  125. megdnn_assert(param().format == param::Resize::Format::NCHW4,
  126. "invalid resize format");
  127. megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS8);
  128. C = src.layout.shape[1] * 4;
  129. IH = src.layout.shape[2];
  130. IW = src.layout.shape[3];
  131. OH = dst.layout.shape[2];
  132. OW = dst.layout.shape[3];
  133. resize::forward_proxy_nchw4(src.compatible_ptr<int8_t>(),
  134. dst.compatible_ptr<int8_t>(), src.layout[0],
  135. C, IH, IW, OH, OW, stream);
  136. return;
  137. }
  138. megdnn_assert(param().imode == Param::InterpolationMode::LINEAR,
  139. "unsupported interpolation mode for NCHW format");
  140. if (src.layout.dtype == dtype::Float32{}) {
  141. resize::forward_proxy(is_nhwc, src.ptr<dt_float32>(),
  142. dst.ptr<dt_float32>(), src.layout[0], C, IH, IW,
  143. OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
  144. } else if (src.layout.dtype == dtype::Uint8()) {
  145. resize::forward_proxy(is_nhwc, src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
  146. src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
  147. S_IH, S_IW, stream);
  148. } else if (src.layout.dtype == dtype::Int8()) {
  149. resize::forward_proxy(is_nhwc, src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
  150. src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
  151. S_IH, S_IW, stream);
  152. } else {
  153. megdnn_throw(
  154. ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
  155. }
  156. }
  157. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台