You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

pooling.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /**
  2. * \file dnn/src/common/pooling.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
  16. TensorLayout& dst) {
  17. auto errmsg =
  18. megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
  19. megdnn_mangle("pad_h=") + std::to_string(param().pad_h) + ", " +
  20. megdnn_mangle("pad_w=") + std::to_string(param().pad_w) + ", " +
  21. megdnn_mangle("stride_h=") + std::to_string(param().stride_h) +
  22. ", " + megdnn_mangle("stride_w=") +
  23. std::to_string(param().stride_w) + ", " +
  24. megdnn_mangle("window_h=") + std::to_string(param().window_h) +
  25. ", " + megdnn_mangle("window_w=") +
  26. std::to_string(param().window_w) + ", " + megdnn_mangle("is_max=") +
  27. std::to_string(param().mode == Mode::MAX) + ", " +
  28. megdnn_mangle("is_nhwc=") +
  29. std::to_string(param().format == Param::Format::NHWC) + ", " +
  30. megdnn_mangle("is_nhwcd4=") +
  31. std::to_string(param().format == Param::Format::NHWCD4);
  32. auto errmsg_c = errmsg.c_str();
  33. MEGDNN_MARK_USED_VAR(errmsg_c);
  34. megdnn_assert_contiguous(src);
  35. size_t spatial_pos, c_pos, batch_pos = 0;
  36. if (param().format == Param::Format::NCHW) {
  37. megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
  38. spatial_pos = 2;
  39. c_pos = 1;
  40. } else if (param().format == Param::Format::NHWC) {
  41. megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
  42. spatial_pos = 1;
  43. c_pos = 3;
  44. } else if (param().format == Param::Format::NCHW4 ||
  45. param().format == Param::Format::NCHW44 ||
  46. param().format == Param::Format::NCHW88 ||
  47. param().format == Param::Format::NCHW32) {
  48. megdnn_assert(src.ndim == 5_z, "%s", errmsg_c);
  49. spatial_pos = 2;
  50. c_pos = 1;
  51. } else if (param().format == Param::Format::CHWN4) {
  52. spatial_pos = 1;
  53. c_pos = 0;
  54. batch_pos = 3;
  55. } else {
  56. megdnn_assert(
  57. param().format == Param::Format::NHWCD4 && src.ndim == 5_z,
  58. "%s", errmsg_c);
  59. spatial_pos = 1;
  60. c_pos = 2;
  61. }
  62. size_t n = src[batch_pos];
  63. size_t c = src[c_pos];
  64. size_t ih = src[spatial_pos];
  65. size_t iw = src[spatial_pos + 1];
  66. if (param().format == Param::Format::NHWCD4) {
  67. c *= 4;
  68. iw = src[spatial_pos + 2];
  69. }
  70. if (param().format == Param::Format::NCHW4 ||
  71. param().format == Param::Format::NCHW44 ||
  72. param().format == Param::Format::CHWN4) {
  73. c *= 4;
  74. }
  75. if (param().format == Param::Format::NCHW88) {
  76. c *= 8;
  77. }
  78. if (param().format == Param::Format::NCHW32) {
  79. c *= 32;
  80. }
  81. size_t oh, ow;
  82. size_t fh = this->param().window_h;
  83. size_t fw = this->param().window_w;
  84. size_t sh = this->param().stride_h;
  85. size_t sw = this->param().stride_w;
  86. size_t ph = this->param().pad_h;
  87. size_t pw = this->param().pad_w;
  88. if (ph < fh && pw < fw) {
  89. megdnn_log_error(
  90. "pooling padding size (%zu %zu) should not be bigger than "
  91. "window size (%zu %zu), it only can be used in CaffePooling",
  92. pw, ph, fw, fh);
  93. }
  94. infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow);
  95. if (param().format == Param::Format::NCHW) {
  96. dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype);
  97. } else if (param().format == Param::Format::NHWC) {
  98. megdnn_assert(param().format == Param::Format::NHWC,
  99. "invalid pooling format");
  100. dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format);
  101. } else if (param().format == Param::Format::NCHW4 ||
  102. param().format == Param::Format::NCHW44) {
  103. dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format};
  104. } else if (param().format == Param::Format::NCHW88) {
  105. dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format};
  106. } else if (param().format == Param::Format::NCHW32) {
  107. dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format};
  108. } else if (param().format == Param::Format::CHWN4) {
  109. dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format};
  110. } else {
  111. megdnn_assert(param().format == Param::Format::NHWCD4,
  112. "invalid pooling format");
  113. dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format};
  114. }
  115. }
  116. void PoolingBase::check_layout_fwd(const TensorLayout& src,
  117. const TensorLayout& dst) {
  118. TensorLayout dst_expected;
  119. megdnn_assert_eq_dtype(src, dst);
  120. deduce_layout_fwd(src, dst_expected);
  121. megdnn_assert_eq_layout(dst_expected, dst);
  122. megdnn_assert(src.dtype == dst.dtype);
  123. megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT ||
  124. src.dtype == dtype::Int8() ||
  125. src.dtype.category() == DTypeCategory::QUANTIZED);
  126. }
  127. void PoolingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  128. deduce_layout_fwd(src, dst);
  129. }
  130. void PoolingForward::check_exec(const TensorLayout& src,
  131. const TensorLayout& dst,
  132. size_t workspace_in_bytes) {
  133. check_layout_fwd(src, dst);
  134. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  135. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  136. }
  137. void PoolingBackward::check_exec(const TensorLayout& src,
  138. const TensorLayout& dst,
  139. const TensorLayout& diff,
  140. const TensorLayout& grad,
  141. size_t workspace_in_bytes) {
  142. check_layout_fwd(src, dst);
  143. megdnn_assert_eq_layout(src, grad);
  144. megdnn_assert_eq_layout(dst, diff);
  145. auto required_workspace_in_bytes =
  146. get_workspace_in_bytes(src, dst, diff, grad);
  147. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  148. }
  149. } // namespace megdnn
  150. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台