You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

padding.cpp 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. /**
  2. * \file dnn/src/common/padding.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "megdnn/oprs/general.h"
  14. #include "megdnn/thin/small_vector.h"
  15. #include "src/common/opr_param_defs_enumv.cuh"
  16. #include "src/common/utils.h"
  17. namespace megdnn {
  18. using padding_param = megdnn::param_enumv::Padding;
  19. void PaddingForward::forward_check_exec(const TensorLayout& src,
  20. const TensorLayout& dst) {
  21. check_exec(src, dst);
  22. megdnn_assert(src.dtype.enumv() != DTypeEnum::Bool &&
  23. src.dtype.enumv() != DTypeEnum::IntB1 &&
  24. src.dtype.enumv() != DTypeEnum::IntB2 &&
  25. src.dtype.enumv() != DTypeEnum::IntB4,
  26. "unsupported %s dtype for forward padding opr",
  27. src.dtype.name());
  28. }
  29. void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  30. SmallVector<size_t> offsets(get_offsets());
  31. TensorShape dst_shape;
  32. switch (src.ndim) {
  33. case 1:
  34. dst_shape = {src.shape[0] + offsets[0] + offsets[1]};
  35. break;
  36. case 2:
  37. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  38. src.shape[1] + offsets[2] + offsets[3]};
  39. break;
  40. case 3:
  41. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  42. src.shape[1] + offsets[2] + offsets[3],
  43. src.shape[2] + offsets[4] + offsets[5]};
  44. break;
  45. case 4:
  46. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  47. src.shape[1] + offsets[2] + offsets[3],
  48. src.shape[2] + offsets[4] + offsets[5],
  49. src.shape[3] + offsets[6] + offsets[7]};
  50. break;
  51. case 5:
  52. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  53. src.shape[1] + offsets[2] + offsets[3],
  54. src.shape[2] + offsets[4] + offsets[5],
  55. src.shape[3] + offsets[6] + offsets[7],
  56. src.shape[4] + offsets[8] + offsets[9]};
  57. break;
  58. case 6:
  59. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  60. src.shape[1] + offsets[2] + offsets[3],
  61. src.shape[2] + offsets[4] + offsets[5],
  62. src.shape[3] + offsets[6] + offsets[7],
  63. src.shape[4] + offsets[8] + offsets[9],
  64. src.shape[5] + offsets[10] + offsets[11]};
  65. break;
  66. case 7:
  67. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  68. src.shape[1] + offsets[2] + offsets[3],
  69. src.shape[2] + offsets[4] + offsets[5],
  70. src.shape[3] + offsets[6] + offsets[7],
  71. src.shape[4] + offsets[8] + offsets[9],
  72. src.shape[5] + offsets[10] + offsets[11],
  73. src.shape[6] + offsets[12] + offsets[13]};
  74. break;
  75. default:
  76. megdnn_assert(false, "invalid tensor ndim %zu", src.ndim);
  77. break;
  78. }
  79. dst = TensorLayout(dst_shape, src.dtype);
  80. }
  81. void PaddingBackward::backward_check_exec(const TensorLayout& src,
  82. const TensorLayout& dst) {
  83. check_exec(dst, src);
  84. megdnn_assert(src.dtype.enumv() ==
  85. DTypeEnum::Float32 DNN_INC_FLOAT16(
  86. || src.dtype.enumv() == DTypeEnum::Float16 ||
  87. src.dtype.enumv() == DTypeEnum::BFloat16),
  88. "unsupported %s dtype for forward padding opr",
  89. src.dtype.name());
  90. }
  91. SmallVector<size_t> PaddingBase::get_offsets() {
  92. SmallVector<size_t> offsets = {
  93. param().front_offset_dim0, param().back_offset_dim0,
  94. param().front_offset_dim1, param().back_offset_dim1,
  95. param().front_offset_dim2, param().back_offset_dim2,
  96. param().front_offset_dim3, param().back_offset_dim3,
  97. param().front_offset_dim4, param().back_offset_dim4,
  98. param().front_offset_dim5, param().back_offset_dim5,
  99. param().front_offset_dim6, param().back_offset_dim6};
  100. return offsets;
  101. }
  102. void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) {
  103. SmallVector<size_t> offsets(get_offsets());
  104. // make sure the src and dst tensor not empty
  105. megdnn_assert(src.ndim != 0 && dst.ndim != 0);
  106. // make sure src and dst is same dtype
  107. megdnn_assert_eq_dtype(src, dst);
  108. // make sure src and dst is same ndim
  109. megdnn_assert(src.ndim == dst.ndim, "the src.ndim = %zu the dst.ndim = %zu",
  110. src.ndim, dst.ndim);
  111. // make sure in every dimension dst is equal or greater than src
  112. for (size_t i = 0; i < src.ndim; ++i) {
  113. megdnn_assert(dst.shape[i] ==
  114. src.shape[i] + offsets[i * 2] + offsets[i * 2 + 1]);
  115. }
  116. // check the padding mode is valid
  117. megdnn_assert(static_cast<uint32_t>(param().padding_mode) ==
  118. padding_param::PaddingMode::REFLECT ||
  119. static_cast<uint32_t>(param().padding_mode) ==
  120. padding_param::PaddingMode::REPLICATE ||
  121. static_cast<uint32_t>(param().padding_mode) ==
  122. padding_param::PaddingMode::CONSTANT,
  123. "unsupported padding mode");
  124. // addition check for reflect padding, make sure the reflected index is
  125. // valid
  126. if (static_cast<uint32_t>(param().padding_mode) ==
  127. padding_param::PaddingMode::REFLECT) {
  128. for (size_t i = 0; i < src.ndim; ++i) {
  129. megdnn_assert(offsets[i * 2] < src.shape[i] &&
  130. dst.shape[i] - offsets[i * 2] - src.shape[i] <
  131. src.shape[i]);
  132. }
  133. }
  134. }
  135. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台