You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
/**
 * \file dnn/src/common/resize.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <cmath>
#include <utility>
  13. namespace megdnn {
  14. void ResizeBase::check_layout_fwd(const TensorLayout& src,
  15. const TensorLayout& dst) {
  16. auto errmsg = [&]() {
  17. return megdnn_layout_msg(src) + ", " + ", " + megdnn_layout_msg(dst);
  18. };
  19. MEGDNN_MARK_USED_VAR(errmsg);
  20. megdnn_assert(dst.dtype == src.dtype && dst.shape[0] == src.shape[0], "%s",
  21. errmsg().c_str());
  22. if (param().format == Param::Format::NCHW) {
  23. megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
  24. megdnn_assert(param().imode ==
  25. param::Resize::InterpolationMode::INTER_LINEAR);
  26. } else if (param().format == Param::Format::NHWC) {
  27. megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
  28. } else if (param().format == Param::Format::NCHW4) {
  29. megdnn_assert(src.ndim == 5);
  30. megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
  31. megdnn_assert(src.shape[4] == 4);
  32. megdnn_assert(dst.shape[4] == 4);
  33. } else {
  34. megdnn_assert(param().format == Param::Format::NHWCD4,
  35. "invalid resize tensor format");
  36. megdnn_assert(param().imode ==
  37. param::Resize::InterpolationMode::INTER_LINEAR);
  38. megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str());
  39. }
  40. }
  41. void Resize::check_exec(const TensorLayout& src, const TensorLayout& dst,
  42. size_t workspace_in_bytes) {
  43. check_layout_fwd(src, dst);
  44. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  45. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  46. }
  47. void ResizeBackward::check_exec(const TensorLayout& diff,
  48. const TensorLayout& grad,
  49. size_t workspace_in_bytes) {
  50. check_layout_fwd(grad, diff);
  51. auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
  52. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  53. megdnn_assert(param().format == Param::Format::NCHW &&
  54. grad.dtype == dtype::Float32(),
  55. "Backward resize only supports Float32 and NCHW.");
  56. }
  57. std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
  58. int idx) {
  59. //! copy from resize_cv.cpp
  60. float alpha = (idx + 0.5f) / scale - 0.5f;
  61. int origin_idx = static_cast<int>(floor(alpha));
  62. alpha -= origin_idx;
  63. if (origin_idx < 0) {
  64. origin_idx = 0;
  65. alpha = 0;
  66. } else if (origin_idx + 1 >= size) {
  67. origin_idx = size - 2;
  68. alpha = 1;
  69. }
  70. return {alpha, origin_idx};
  71. }
  72. } // namespace megdnn
  73. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。