You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_local.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. /**
  2. * \file dnn/src/common/group_local.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/nn.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void GroupLocalBase::deduce_layout_fwd(const TensorLayout &src,
  15. const TensorLayout &filter,
  16. TensorLayout &dst)
  17. {
  18. auto errmsg = [&]() {
  19. return megdnn_layout_msg(src) + ", "
  20. + megdnn_layout_msg(filter) + ", "
  21. + megdnn_layout_msg(dst) + ", "
  22. + megdnn_mangle("pad_h=") + std::to_string(param().pad_h) + ", "
  23. + megdnn_mangle("pad_w=") + std::to_string(param().pad_w) + ", "
  24. + megdnn_mangle("stride_h=") + std::to_string(param().stride_h) + ", "
  25. + megdnn_mangle("stride_w=") + std::to_string(param().stride_w);
  26. };
  27. MEGDNN_MARK_USED_VAR(errmsg);
  28. megdnn_assert_contiguous(src);
  29. megdnn_assert_contiguous(filter);
  30. megdnn_assert(param().mode == Mode::CROSS_CORRELATION,
  31. "only CROSS_CORRELATION mode is supported for glocal.");
  32. megdnn_assert(param().sparse == Param::Sparse::DENSE &&
  33. param().dilate_h == 1 && param().dilate_w == 1 &&
  34. src.dtype.category() == DTypeCategory::FLOAT &&
  35. src.dtype == dst.dtype,
  36. "unsupported conv param for Local opr");
  37. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  38. megdnn_assert(filter.ndim == 7_z, "%s", errmsg().c_str());
  39. size_t group = filter[0];
  40. size_t n = src[0];
  41. size_t ic = src[1];
  42. size_t ih = src[2];
  43. size_t iw = src[3];
  44. size_t oc = filter[6]*group;
  45. size_t oh = filter[1], ow = filter[2];
  46. megdnn_assert_eq_size_t(filter[0], group);
  47. megdnn_assert_eq_size_t(filter[3]*group, ic);
  48. size_t fh = filter[4], fw = filter[5];
  49. // (group, oh, ow, ic/group, fh, fw, oc/group)
  50. infer_conv_shape2d(ih, iw, fh, fw,
  51. param().stride_h, param().stride_w,
  52. param().pad_h, param().pad_w, oh, ow);
  53. dst = TensorLayout(TensorShape({n, oc, oh, ow}), src.dtype);
  54. }
  55. void GroupLocalBase::check_layout_fwd(const TensorLayout &src,
  56. const TensorLayout &filter,
  57. const TensorLayout &dst)
  58. {
  59. TensorLayout dst_expected{dst.dtype};
  60. megdnn_assert_eq_dtype(src, filter);
  61. megdnn_assert_eq_dtype(src, dst);
  62. deduce_layout_fwd(src, filter, dst_expected);
  63. megdnn_assert_eq_layout(dst_expected, dst);
  64. megdnn_assert(src.dtype == dtype::Float32() || MEGDNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), true));
  65. }
  66. void GroupLocalForward::check_exec(const TensorLayout &src,
  67. const TensorLayout &filter,
  68. const TensorLayout &dst,
  69. size_t workspace_in_bytes)
  70. {
  71. check_layout_fwd(src, filter, dst);
  72. auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst);
  73. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  74. }
  75. void GroupLocalBackwardData::check_exec(const TensorLayout &filter,
  76. const TensorLayout &diff,
  77. const TensorLayout &grad,
  78. size_t workspace_in_bytes)
  79. {
  80. check_layout_fwd(grad, filter, diff);
  81. auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
  82. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  83. }
  84. void GroupLocalBackwardFilter::check_exec(const TensorLayout &src,
  85. const TensorLayout &diff,
  86. const TensorLayout &grad,
  87. size_t workspace_in_bytes)
  88. {
  89. check_layout_fwd(src, grad, diff);
  90. auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
  91. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  92. }
  93. } // namespace megdnn
  94. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台