You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cpp 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. /**
  2. * \file dnn/src/cuda/convolution/helper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./helper.h"
  12. using namespace megdnn;
  13. using namespace cuda;
  14. using namespace convolution;
  15. bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) {
  16. if (args.src_layout->dtype == args.filter_layout->dtype &&
  17. args.src_layout->dtype == dtype::BFloat16()) {
  18. return false;
  19. }
  20. // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
  21. // on Tegra K1.
  22. if (args.handle->is_tegra_k1())
  23. return false;
  24. // TODO: We only support NCHW format now. It seems cuDNN provides support
  25. // for NHWC as well.
  26. if (args.filter_meta.format == param::Convolution::Format::NCHW4) {
  27. if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 &&
  28. args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) {
  29. return false;
  30. }
  31. } else if (args.filter_meta.format != param::Convolution::Format::NCHW) {
  32. return false;
  33. }
  34. auto& fm = args.filter_meta;
  35. bool supported = true;
  36. supported &= (fm.spatial_ndim == 2);
  37. #if CUDNN_VERSION < 7000
  38. supported &= (fm.group == 1);
  39. #endif
  40. #if CUDNN_VERSION < 7500
  41. supported &= (fm.dilation[0] == 1 && fm.dilation[1] == 1);
  42. #endif
  43. return supported;
  44. }
  45. SmallVector<size_t> convolution::matmul_get_workspace_bundle(
  46. const ForwardSizeArgs &args) {
  47. auto dtype = args.src_layout->dtype;
  48. auto &&fm = args.filter_meta;
  49. megdnn_assert(fm.group == 1);
  50. auto N = args.src_layout->shape[0];
  51. auto OC = fm.ocpg,
  52. IC = fm.icpg,
  53. FH = fm.spatial[0],
  54. FW = fm.spatial[1];
  55. auto OH = args.dst_layout->shape[2],
  56. OW = args.dst_layout->shape[3];
  57. SmallVector<size_t> sizes{
  58. dtype.size() * args.dst_layout->total_nr_elems(),
  59. dtype.size() * IC*FH*FW*OH*OW*N
  60. };
  61. if (args.filter_meta.should_flip) {
  62. sizes.push_back(dtype.size() * OC * IC * FH * FW);
  63. }
  64. return sizes;
  65. }
  66. void convolution::flip_filter(const ForwardSizeArgs &args,
  67. const Workspace &workspace, void *&raw_ptr) {
  68. auto &&fm = args.filter_meta;
  69. megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2);
  70. auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1];
  71. auto dtype = fm.dtype;
  72. megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW);
  73. TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}},
  74. dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout};
  75. dst.layout.stride[2] = -dst.layout.stride[2];
  76. dst.layout.stride[3] = -dst.layout.stride[3];
  77. args.handle->relayout_opr()->exec(src, dst);
  78. raw_ptr = workspace.raw_ptr;
  79. }
  80. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。