You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /**
  2. * \file dnn/test/fallback/resize.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/resize.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/task_record_check.h"
  14. #include "test/fallback/fixture.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(FALLBACK, RESIZE_CV) {
  18. using namespace resize;
  19. std::vector<TestArg> args = get_cv_args();
  20. Checker<Resize> checker(handle());
  21. for (auto&& arg : args) {
  22. checker.set_param(arg.param)
  23. .set_dtype(0, dtype::Uint8())
  24. .set_dtype(1, dtype::Uint8())
  25. .set_epsilon(1 + 1e-3)
  26. .execs({arg.src, arg.dst});
  27. }
  28. for (auto&& arg : args) {
  29. checker.set_param(arg.param)
  30. .set_dtype(0, dtype::Float32())
  31. .set_dtype(1, dtype::Float32())
  32. .execs({arg.src, arg.dst});
  33. }
  34. }
  35. TEST_F(FALLBACK, RESIZE_CV_RECORD) {
  36. using namespace resize;
  37. std::vector<TestArg> args = get_cv_args();
  38. TaskRecordChecker<Resize> checker(1);
  39. for (auto&& arg : args) {
  40. checker.set_param(arg.param)
  41. .set_dtype(0, dtype::Uint8())
  42. .set_dtype(1, dtype::Uint8())
  43. .set_epsilon(1 + 1e-3)
  44. .execs({arg.src, arg.dst});
  45. }
  46. for (auto&& arg : args) {
  47. checker.set_param(arg.param)
  48. .set_dtype(0, dtype::Float32())
  49. .set_dtype(1, dtype::Float32())
  50. .execs({arg.src, arg.dst});
  51. }
  52. }
  53. TEST_F(FALLBACK, RESIZE) {
  54. using namespace resize;
  55. std::vector<TestArg> args = get_args();
  56. Checker<Resize> checker(handle());
  57. for (auto&& arg : args) {
  58. checker.set_param(arg.param)
  59. .set_dtype(0, dtype::Uint8())
  60. .set_dtype(1, dtype::Uint8())
  61. .set_epsilon(1 + 1e-3)
  62. .execs({arg.src, arg.dst});
  63. }
  64. for (auto&& arg : args) {
  65. checker.set_param(arg.param)
  66. .set_dtype(0, dtype::Float32())
  67. .set_dtype(1, dtype::Float32())
  68. .execs({arg.src, arg.dst});
  69. }
  70. }
  71. TEST_F(FALLBACK, RESIZE_RECORD) {
  72. using namespace resize;
  73. std::vector<TestArg> args = get_args();
  74. TaskRecordChecker<Resize> checker(1);
  75. for (auto&& arg : args) {
  76. checker.set_param(arg.param)
  77. .set_dtype(0, dtype::Uint8())
  78. .set_dtype(1, dtype::Uint8())
  79. .set_epsilon(1 + 1e-3)
  80. .execs({arg.src, arg.dst});
  81. }
  82. for (auto&& arg : args) {
  83. checker.set_param(arg.param)
  84. .set_dtype(0, dtype::Float32())
  85. .set_dtype(1, dtype::Float32())
  86. .execs({arg.src, arg.dst});
  87. }
  88. }
  89. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE) {
  90. param::Resize param;
  91. param.format = param::Resize::Format::NCHW;
  92. param.imode = param::Resize::InterpolationMode::LINEAR;
  93. Checker<Resize> checker(handle());
  94. checker.set_epsilon(1 + 1e-3).set_param(param);
  95. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  96. TensorShape dst_shape, DType dtype) {
  97. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  98. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  99. };
  100. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  101. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  102. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  103. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  104. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  105. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  106. }
  107. }
  108. TEST_F(FALLBACK, RESIZE_NCHW_WITH_STRIDE_RECORD) {
  109. param::Resize param;
  110. param.format = param::Resize::Format::NCHW;
  111. param.imode = param::Resize::InterpolationMode::LINEAR;
  112. TaskRecordChecker<Resize> checker(1);
  113. checker.set_epsilon(1 + 1e-3).set_param(param);
  114. auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
  115. TensorShape dst_shape, DType dtype) {
  116. checker.set_dtype(0, dtype).set_dtype(1, dtype).execl(
  117. {{src_shape, src_layout, dtype}, {dst_shape, dtype}});
  118. };
  119. for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8()}) {
  120. run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
  121. run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
  122. run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
  123. run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
  124. run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
  125. }
  126. }
  127. TEST_F(FALLBACK, RESIZE_NCHW4) {
  128. using namespace resize;
  129. auto args = get_nchw4_args();
  130. Checker<Resize> checker(handle());
  131. for (auto&& arg : args) {
  132. checker.set_param(arg.param)
  133. .set_dtype(0, dtype::QuantizedS8(1.0f))
  134. .set_dtype(1, dtype::QuantizedS8(1.0f))
  135. .set_epsilon(1 + 1e-3)
  136. .execs({arg.src, arg.dst});
  137. }
  138. }
  139. TEST_F(FALLBACK, RESIZE_NCHW4_RECORD) {
  140. using namespace resize;
  141. auto args = get_nchw4_args();
  142. TaskRecordChecker<Resize> checker(1);
  143. for (auto&& arg : args) {
  144. checker.set_param(arg.param)
  145. .set_dtype(0, dtype::QuantizedS8(1.0f))
  146. .set_dtype(1, dtype::QuantizedS8(1.0f))
  147. .set_epsilon(1 + 1e-3)
  148. .execs({arg.src, arg.dst});
  149. }
  150. }
  151. namespace {
  152. static void set_nchw_args(resize::IMode imode, std::vector<resize::TestArg>& args) {
  153. param::Resize param;
  154. param.format = param::Resize::Format::NCHW;
  155. param.imode = imode;
  156. rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
  157. args.emplace_back(
  158. param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul},
  159. TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul});
  160. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 20, 20});
  161. args.emplace_back(param, TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 7, 9});
  162. args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
  163. args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
  164. }
  165. } // namespace
  166. TEST_F(FALLBACK, RESIZE_NCHW_FP32) {
  167. std::vector<resize::TestArg> args;
  168. set_nchw_args(resize::IMode::INTER_LINEAR, args);
  169. set_nchw_args(resize::IMode::INTER_NEAREST, args);
  170. Checker<Resize> checker(handle());
  171. for (auto&& arg : args) {
  172. checker.set_param(arg.param)
  173. .set_dtype(0, dtype::Float32())
  174. .set_dtype(1, dtype::Float32())
  175. .execs({arg.src, arg.dst});
  176. }
  177. }
  178. TEST_F(FALLBACK, RESIZE_NCHW44_FP32) {
  179. std::vector<resize::TestArg> args = resize::get_nchw44_args();
  180. Checker<Resize> checker(handle());
  181. for (auto&& arg : args) {
  182. checker.set_param(arg.param)
  183. .set_dtype(0, dtype::Float32())
  184. .set_dtype(1, dtype::Float32())
  185. .execs({arg.src, arg.dst});
  186. }
  187. }
  188. } // namespace test
  189. } // namespace megdnn
  190. // vim: syntax=cpp.doxygen