
api_cache.h 4.5 kB

/**
 * \file dnn/src/cuda/api_cache.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/common/api_cache.h"
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {

// Serializes/deserializes a cudnnConvolutionDescriptor_t (padding, strides,
// dilations, mode, compute type) so it can be used as part of a cache key.
class CudnnConvDescParam {
public:
    cudnnConvolutionDescriptor_t value;
    Empty serialize(StringSerializer& ser, Empty) {
        int nbDims = MEGDNN_MAX_NDIM;
        int padA[MEGDNN_MAX_NDIM];
        int strideA[MEGDNN_MAX_NDIM];
        int dilationA[MEGDNN_MAX_NDIM];
        cudnnConvolutionMode_t mode;
        cudnnDataType_t computeType;
        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
                                        dilationA, &mode, &computeType);
        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(padA[i]);
            ser.write_plain(strideA[i]);
            ser.write_plain(dilationA[i]);
        }
        ser.write_plain(mode);
        ser.write_plain(computeType);
        return Empty{};
    }
    Empty deserialize(StringSerializer& ser, Empty) {
        int ndim = ser.read_plain<int>();
        int padA[MEGDNN_MAX_NDIM];
        int strideA[MEGDNN_MAX_NDIM];
        int dilationA[MEGDNN_MAX_NDIM];
        for (int i = 0; i < ndim; ++i) {
            padA[i] = ser.read_plain<int>();
            strideA[i] = ser.read_plain<int>();
            dilationA[i] = ser.read_plain<int>();
        }
        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
                                        mode, computeType);
        return Empty{};
    }
};

// Serializes/deserializes a cudnnTensorDescriptor_t (data type, dims, strides).
class CudnnTensorDescParam {
public:
    cudnnTensorDescriptor_t value;
    Empty serialize(StringSerializer& ser, Empty) {
        int nbDims = MEGDNN_MAX_NDIM;
        cudnnDataType_t dataType;
        int dimA[MEGDNN_MAX_NDIM];
        int strideA[MEGDNN_MAX_NDIM];
        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
                                   strideA);
        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(dimA[i]);
            ser.write_plain(strideA[i]);
        }
        ser.write_plain(dataType);
        return Empty{};
    }
    Empty deserialize(StringSerializer& ser, Empty) {
        cudnnDataType_t dataType;
        int dimA[MEGDNN_MAX_NDIM];
        int strideA[MEGDNN_MAX_NDIM];
        int nbDims = ser.read_plain<int>();
        for (int i = 0; i < nbDims; ++i) {
            dimA[i] = ser.read_plain<int>();
            strideA[i] = ser.read_plain<int>();
        }
        dataType = ser.read_plain<cudnnDataType_t>();
        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
        return Empty{};
    }
};

// Serializes/deserializes a cudnnFilterDescriptor_t (data type, layout, dims).
class CudnnFilterDescParam {
public:
    cudnnFilterDescriptor_t value;
    Empty serialize(StringSerializer& ser, Empty) {
        int nbDims = MEGDNN_MAX_NDIM;
        cudnnDataType_t dataType;
        cudnnTensorFormat_t format;
        int filterDimA[MEGDNN_MAX_NDIM];
        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
                                   filterDimA);
        ser.write_plain(nbDims);
        for (int i = 0; i < nbDims; ++i) {
            ser.write_plain(filterDimA[i]);
        }
        ser.write_plain(dataType);
        ser.write_plain(format);
        return Empty{};
    }
    Empty deserialize(StringSerializer& ser, Empty) {
        cudnnDataType_t dataType;
        cudnnTensorFormat_t format;
        int filterDimA[MEGDNN_MAX_NDIM];
        int nbDims = ser.read_plain<int>();
        for (int i = 0; i < nbDims; ++i) {
            filterDimA[i] = ser.read_plain<int>();
        }
        dataType = ser.read_plain<cudnnDataType_t>();
        format = ser.read_plain<cudnnTensorFormat_t>();
        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
        return Empty{};
    }
};

}  // namespace megdnn
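
For reference, the sketch below shows how one of these param classes round-trips a descriptor: the shape information is read out of a cudnnTensorDescriptor_t, written into the serializer, and later replayed onto a freshly created descriptor. It is a minimal illustration only, not part of the file above; it assumes StringSerializer (declared in src/common/api_cache.h) is default-constructible and that read_plain returns values in the same order write_plain recorded them, which is not shown here.

#include <cudnn.h>
#include "src/common/api_cache.h"
#include "src/cuda/api_cache.h"

using namespace megdnn;

void roundtrip_tensor_desc_example() {
    // Describe an NCHW float tensor of shape (8, 64, 32, 32) with cuDNN.
    cudnnTensorDescriptor_t src;
    cudnnCreateTensorDescriptor(&src);
    cudnnSetTensor4dDescriptor(src, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                               8, 64, 32, 32);

    // Assumption: StringSerializer can be default-constructed and replays
    // values in write order when read back.
    StringSerializer ser;

    // Record the descriptor contents into the serializer.
    CudnnTensorDescParam src_param;
    src_param.value = src;
    src_param.serialize(ser, Empty{});

    // Replay the recorded contents onto a new descriptor.
    cudnnTensorDescriptor_t dst;
    cudnnCreateTensorDescriptor(&dst);
    CudnnTensorDescParam dst_param;
    dst_param.value = dst;
    dst_param.deserialize(ser, Empty{});

    // dst now describes the same (dtype, dims, strides) as src.
    cudnnDestroyTensorDescriptor(src);
    cudnnDestroyTensorDescriptor(dst);
}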
