You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

separableConv.cpp 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * \file dnn/src/common/separableConv.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void SeparableConvBase::deduce_layout_fwd(const TensorLayout &src,
  15. const TensorLayout &filter_x,
  16. const TensorLayout &filter_y,
  17. TensorLayout &dst)
  18. {
  19. auto errmsg = [&]() {
  20. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter_x) +
  21. ", " + megdnn_layout_msg(dst) + ", " +
  22. "is_xcorr=" + "borderMode=" +
  23. std::to_string((param().mode == Mode::CROSS_CORRELATION)) +
  24. ", " + std::to_string((int)(param().borderMode)) + ", " +
  25. "pad_h=" + std::to_string(param().pad_h) + ", " +
  26. "pad_w=" + std::to_string(param().pad_w) + ", " +
  27. "stride_h=" + std::to_string(param().stride_h) + ", " +
  28. "stride_w=" + std::to_string(param().stride_w);
  29. };
  30. MEGDNN_MARK_USED_VAR(errmsg);
  31. megdnn_assert_contiguous(src);
  32. megdnn_assert_contiguous(filter_x);
  33. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  34. megdnn_assert(filter_x.ndim == 4_z, "%s", errmsg().c_str());
  35. size_t n = src[0];
  36. size_t ic = src[1];
  37. size_t ih = src[2];
  38. size_t iw = src[3];
  39. size_t oc = filter_x[0];
  40. megdnn_assert_eq_layout(filter_x, filter_y);
  41. megdnn_assert(filter_x[1] == ic, "%s", errmsg().c_str());
  42. size_t fw = filter_x[3];
  43. size_t fh = fw;
  44. size_t sh = this->param().stride_h;
  45. size_t sw = this->param().stride_w;
  46. size_t ph = this->param().pad_h;
  47. size_t pw = this->param().pad_w;
  48. size_t oh, ow;
  49. infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow);
  50. dst = TensorLayout(TensorShape({n, oc, oh, ow}), src.dtype);
  51. }
  52. void SeparableConvBase::check_layout_fwd(const TensorLayout &src,
  53. const TensorLayout &filter_x,
  54. const TensorLayout &filter_y,
  55. const TensorLayout &dst)
  56. {
  57. TensorLayout dst_expected;
  58. megdnn_assert_eq_dtype(src, filter_x);
  59. megdnn_assert_eq_dtype(src, filter_y);
  60. megdnn_assert_eq_layout(filter_x, filter_y);
  61. megdnn_assert_eq_dtype(src, dst);
  62. deduce_layout_fwd(src, filter_x, filter_y, dst_expected);
  63. megdnn_assert_eq_layout(dst_expected, dst);
  64. }
  65. void SeparableConvForward::deduce_layout(const TensorLayout &src,
  66. const TensorLayout &filter_x,
  67. const TensorLayout &filter_y,
  68. TensorLayout &dst)
  69. {
  70. deduce_layout_fwd(src, filter_x, filter_y, dst);
  71. }
  72. void SeparableConvForward::check_exec(const TensorLayout &src,
  73. const TensorLayout &filter_x,
  74. const TensorLayout &filter_y,
  75. const TensorLayout &dst,
  76. size_t workspace_in_bytes)
  77. {
  78. check_layout_fwd(src, filter_x, filter_y, dst);
  79. auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter_x, filter_y, dst);
  80. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  81. }
  82. } // namespace megdnn
  83. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台