You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

pooling.cpp 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. /**
  2. * \file dnn/src/common/pooling.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
  16. TensorLayout& dst) {
  17. auto errmsg =
  18. megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
  19. "pad_h=" + std::to_string(param().pad_h) + ", " +
  20. "pad_w=" + std::to_string(param().pad_w) + ", " +
  21. "stride_h=" + std::to_string(param().stride_h) + ", " +
  22. "stride_w=" + std::to_string(param().stride_w) + ", " +
  23. "window_h=" + std::to_string(param().window_h) + ", " +
  24. "window_w=" + std::to_string(param().window_w) + ", " +
  25. "is_max=" + std::to_string(param().mode == Mode::MAX) + ", " +
  26. "is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) +
  27. ", " + "is_nhwcd4=" +
  28. std::to_string(param().format == Param::Format::NHWCD4);
  29. auto errmsg_c = errmsg.c_str();
  30. MEGDNN_MARK_USED_VAR(errmsg_c);
  31. megdnn_assert_contiguous(src);
  32. size_t spatial_pos, c_pos, batch_pos = 0;
  33. if (param().format == Param::Format::NCHW) {
  34. megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
  35. spatial_pos = 2;
  36. c_pos = 1;
  37. } else if (param().format == Param::Format::NHWC) {
  38. megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
  39. spatial_pos = 1;
  40. c_pos = 3;
  41. } else if (param().format == Param::Format::NCHW4 ||
  42. param().format == Param::Format::NCHW44 ||
  43. param().format == Param::Format::NCHW88 ||
  44. param().format == Param::Format::NCHW32) {
  45. megdnn_assert(src.ndim == 5_z, "%s", errmsg_c);
  46. spatial_pos = 2;
  47. c_pos = 1;
  48. } else if (param().format == Param::Format::CHWN4) {
  49. spatial_pos = 1;
  50. c_pos = 0;
  51. batch_pos = 3;
  52. } else {
  53. megdnn_assert(
  54. param().format == Param::Format::NHWCD4 && src.ndim == 5_z,
  55. "%s", errmsg_c);
  56. spatial_pos = 1;
  57. c_pos = 2;
  58. }
  59. size_t n = src[batch_pos];
  60. size_t c = src[c_pos];
  61. size_t ih = src[spatial_pos];
  62. size_t iw = src[spatial_pos + 1];
  63. if (param().format == Param::Format::NHWCD4) {
  64. c *= 4;
  65. iw = src[spatial_pos + 2];
  66. }
  67. if (param().format == Param::Format::NCHW4 ||
  68. param().format == Param::Format::NCHW44 ||
  69. param().format == Param::Format::CHWN4) {
  70. c *= 4;
  71. }
  72. if (param().format == Param::Format::NCHW88) {
  73. c *= 8;
  74. }
  75. if (param().format == Param::Format::NCHW32) {
  76. c *= 32;
  77. }
  78. size_t oh, ow;
  79. size_t fh = this->param().window_h;
  80. size_t fw = this->param().window_w;
  81. size_t sh = this->param().stride_h;
  82. size_t sw = this->param().stride_w;
  83. size_t ph = this->param().pad_h;
  84. size_t pw = this->param().pad_w;
  85. if (ph >= fh || pw >= fw) {
  86. megdnn_log_error(
  87. "pooling padding size (%zu %zu) should not be bigger than "
  88. "window size (%zu %zu), it only can be used in CaffePooling",
  89. pw, ph, fw, fh);
  90. }
  91. infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow);
  92. if (param().format == Param::Format::NCHW) {
  93. dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype);
  94. } else if (param().format == Param::Format::NHWC) {
  95. megdnn_assert(param().format == Param::Format::NHWC,
  96. "invalid pooling format");
  97. dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format);
  98. } else if (param().format == Param::Format::NCHW4 ||
  99. param().format == Param::Format::NCHW44) {
  100. dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format};
  101. } else if (param().format == Param::Format::NCHW88) {
  102. dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format};
  103. } else if (param().format == Param::Format::NCHW32) {
  104. dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format};
  105. } else if (param().format == Param::Format::CHWN4) {
  106. dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format};
  107. } else {
  108. megdnn_assert(param().format == Param::Format::NHWCD4,
  109. "invalid pooling format");
  110. dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format};
  111. }
  112. }
  113. void PoolingBase::check_layout_fwd(const TensorLayout& src,
  114. const TensorLayout& dst) {
  115. TensorLayout dst_expected;
  116. megdnn_assert_eq_dtype(src, dst);
  117. deduce_layout_fwd(src, dst_expected);
  118. megdnn_assert_eq_layout(dst_expected, dst);
  119. megdnn_assert(src.dtype == dst.dtype);
  120. megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT ||
  121. src.dtype == dtype::Int8() ||
  122. src.dtype.category() == DTypeCategory::QUANTIZED);
  123. }
// Deduce the forward output layout; delegates to the shared base-class
// implementation (see PoolingBase::deduce_layout_fwd).
void PoolingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
deduce_layout_fwd(src, dst);
}
  127. void PoolingForward::check_exec(const TensorLayout& src,
  128. const TensorLayout& dst,
  129. size_t workspace_in_bytes) {
  130. check_layout_fwd(src, dst);
  131. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  132. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  133. }
  134. void PoolingBackward::check_exec(const TensorLayout& src,
  135. const TensorLayout& dst,
  136. const TensorLayout& diff,
  137. const TensorLayout& grad,
  138. size_t workspace_in_bytes) {
  139. check_layout_fwd(src, dst);
  140. megdnn_assert_eq_layout(src, grad);
  141. megdnn_assert_eq_layout(dst, diff);
  142. auto required_workspace_in_bytes =
  143. get_workspace_in_bytes(src, dst, diff, grad);
  144. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  145. }
  146. } // namespace megdnn
  147. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台