You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

padding.cpp 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * \file dnn/src/common/padding.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "megdnn/oprs/general.h"
  14. #include "megdnn/thin/small_vector.h"
  15. #include "src/common/opr_param_defs_enumv.cuh"
  16. #include "src/common/utils.h"
  17. namespace megdnn {
  18. using padding_param = megdnn::param_enumv::Padding;
  19. void PaddingForward::forward_check_exec(
  20. const TensorLayout& src, const TensorLayout& dst) {
  21. check_exec(src, dst);
  22. megdnn_assert(
  23. src.dtype.enumv() != DTypeEnum::Bool &&
  24. src.dtype.enumv() != DTypeEnum::IntB1 &&
  25. src.dtype.enumv() != DTypeEnum::IntB2 &&
  26. src.dtype.enumv() != DTypeEnum::IntB4,
  27. "unsupported %s dtype for forward padding opr", src.dtype.name());
  28. }
  29. void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  30. SmallVector<size_t> offsets(get_offsets());
  31. TensorShape dst_shape;
  32. switch (src.ndim) {
  33. case 1:
  34. dst_shape = {src.shape[0] + offsets[0] + offsets[1]};
  35. break;
  36. case 2:
  37. dst_shape = {
  38. src.shape[0] + offsets[0] + offsets[1],
  39. src.shape[1] + offsets[2] + offsets[3]};
  40. break;
  41. case 3:
  42. dst_shape = {
  43. src.shape[0] + offsets[0] + offsets[1],
  44. src.shape[1] + offsets[2] + offsets[3],
  45. src.shape[2] + offsets[4] + offsets[5]};
  46. break;
  47. case 4:
  48. dst_shape = {
  49. src.shape[0] + offsets[0] + offsets[1],
  50. src.shape[1] + offsets[2] + offsets[3],
  51. src.shape[2] + offsets[4] + offsets[5],
  52. src.shape[3] + offsets[6] + offsets[7]};
  53. break;
  54. case 5:
  55. dst_shape = {
  56. src.shape[0] + offsets[0] + offsets[1],
  57. src.shape[1] + offsets[2] + offsets[3],
  58. src.shape[2] + offsets[4] + offsets[5],
  59. src.shape[3] + offsets[6] + offsets[7],
  60. src.shape[4] + offsets[8] + offsets[9]};
  61. break;
  62. case 6:
  63. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  64. src.shape[1] + offsets[2] + offsets[3],
  65. src.shape[2] + offsets[4] + offsets[5],
  66. src.shape[3] + offsets[6] + offsets[7],
  67. src.shape[4] + offsets[8] + offsets[9],
  68. src.shape[5] + offsets[10] + offsets[11]};
  69. break;
  70. case 7:
  71. dst_shape = {src.shape[0] + offsets[0] + offsets[1],
  72. src.shape[1] + offsets[2] + offsets[3],
  73. src.shape[2] + offsets[4] + offsets[5],
  74. src.shape[3] + offsets[6] + offsets[7],
  75. src.shape[4] + offsets[8] + offsets[9],
  76. src.shape[5] + offsets[10] + offsets[11],
  77. src.shape[6] + offsets[12] + offsets[13]};
  78. break;
  79. default:
  80. megdnn_assert(false, "invalid tensor ndim %zu", src.ndim);
  81. break;
  82. }
  83. dst = TensorLayout(dst_shape, src.dtype);
  84. }
  85. void PaddingBackward::backward_check_exec(
  86. const TensorLayout& src, const TensorLayout& dst) {
  87. check_exec(dst, src);
  88. megdnn_assert(
  89. src.dtype.enumv() == DTypeEnum::Float32 DNN_INC_FLOAT16(
  90. || src.dtype.enumv() == DTypeEnum::Float16 ||
  91. src.dtype.enumv() == DTypeEnum::BFloat16),
  92. "unsupported %s dtype for forward padding opr", src.dtype.name());
  93. }
  94. SmallVector<size_t> PaddingBase::get_offsets() {
  95. SmallVector<size_t> offsets = {param().front_offset_dim0, param().back_offset_dim0,
  96. param().front_offset_dim1, param().back_offset_dim1,
  97. param().front_offset_dim2, param().back_offset_dim2,
  98. param().front_offset_dim3, param().back_offset_dim3,
  99. param().front_offset_dim4, param().back_offset_dim4,
  100. param().front_offset_dim5, param().back_offset_dim5,
  101. param().front_offset_dim6, param().back_offset_dim6};
  102. return offsets;
  103. }
  104. void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) {
  105. SmallVector<size_t> offsets(get_offsets());
  106. // make sure the src and dst tensor not empty
  107. megdnn_assert(src.ndim != 0 && dst.ndim != 0);
  108. // make sure src and dst is same dtype
  109. megdnn_assert_eq_dtype(src, dst);
  110. // make sure src and dst is same ndim
  111. megdnn_assert(
  112. src.ndim == dst.ndim, "the src.ndim = %zu the dst.ndim = %zu", src.ndim,
  113. dst.ndim);
  114. // make sure in every dimension dst is equal or greater than src
  115. for (size_t i = 0; i < src.ndim; ++i) {
  116. megdnn_assert(
  117. dst.shape[i] == src.shape[i] + offsets[i * 2] + offsets[i * 2 + 1]);
  118. }
  119. // check the padding mode is valid
  120. megdnn_assert(
  121. static_cast<uint32_t>(param().padding_mode) ==
  122. padding_param::PaddingMode::REFLECT ||
  123. static_cast<uint32_t>(param().padding_mode) ==
  124. padding_param::PaddingMode::REPLICATE ||
  125. static_cast<uint32_t>(param().padding_mode) ==
  126. padding_param::PaddingMode::CONSTANT,
  127. "unsupported padding mode");
  128. // addition check for reflect padding, make sure the reflected index is
  129. // valid
  130. if (static_cast<uint32_t>(param().padding_mode) ==
  131. padding_param::PaddingMode::REFLECT) {
  132. for (size_t i = 0; i < src.ndim; ++i) {
  133. megdnn_assert(
  134. offsets[i * 2] < src.shape[i] &&
  135. dst.shape[i] - offsets[i * 2] - src.shape[i] < src.shape[i]);
  136. }
  137. }
  138. }
  139. } // namespace megdnn