You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

megbrain_pubapi.h 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. /**
  2. * \file python_module/src/cpp/megbrain_pubapi.h
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief public API for exposing megbrain internal data structures
  7. *
  8. * This is a pure header without compile-time dependencies.
  9. *
  10. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  11. */
  12. #pragma once
  13. #include <cstdint>
  14. #include <cstddef>
  15. namespace mgb {
  16. namespace pubapi {
  17. /*!
  18. * \brief a general callback that would be invoked exactly once
  19. *
  20. * During the invoke, the functor shoule release related memory
  21. */
  22. struct CallbackOnce {
  23. void (*fptr)(void *);
  24. void *user_data;
  25. //! invoke the callback and clean up the scene
  26. void consume() {
  27. fptr(user_data);
  28. fptr = nullptr;
  29. user_data = nullptr;
  30. }
  31. };
  32. //! tensor on a computing device
  33. class DeviceTensor {
  34. public:
  35. static constexpr uint32_t CURRENT_VERSION = 20190725;
  36. //! device type
  37. enum class Type: uint32_t {
  38. CPU, CUDA
  39. };
  40. enum class DataType: uint32_t {
  41. FLOAT32, FLOAT16, INT32, INT16, INT8, UINT8
  42. };
  43. enum class CopyDirection {
  44. SELF_TO_OTHER, OTHER_TO_SELF
  45. };
  46. struct CudaContext {
  47. int device; //! set to -1 in copy() to use current device
  48. void *stream; //!< set to nullptr for default stream
  49. };
  50. //! tensor descriptor
  51. struct Desc {
  52. Type type;
  53. DataType dtype;
  54. void *dev_ptr; //!< pointer to actual device buffer
  55. const size_t *shape; //!< pointer to shape array
  56. size_t ndim;
  57. //! only valid if type == Type::CUDA
  58. CudaContext cuda_ctx;
  59. };
  60. uint32_t _version0; //!< for consistency check
  61. // note: fields starting with underscore are for internal use only
  62. Desc desc;
  63. size_t size_bytes;
  64. /*!
  65. * \brief synchonize with the calling thread
  66. *
  67. * This must be called before forwarding memory for direct use
  68. *
  69. * \param strong whether to synchronoze the whole device (true), or
  70. * just the computing node (false). Currently it only affects
  71. * how cuda sync is performed.
  72. */
  73. void sync(bool strong = false) const {
  74. m_functable->sync(this, strong);
  75. }
  76. /*!
  77. * \brief copy to/from another buffer
  78. *
  79. * Note: the copy is performed on the comp node on which this tensor
  80. * resides and is always async.
  81. *
  82. * If \p direction is OTHER_TO_SELF and shape of this changes, then
  83. * the corresponding dev_ptr would also be updated.
  84. *
  85. * \param other the other buffer involved in the copy; if
  86. * \p direction is SELF_TO_OTHER, then only its type and
  87. * dev_ptr would be used
  88. * \param direction specify the direction to perform the copy
  89. */
  90. void copy(const Desc &other, CopyDirection direction) {
  91. m_functable->copy(this, other, direction);
  92. }
  93. /*!
  94. * \brief resize this tensor to given shape
  95. */
  96. void resize(size_t ndim, const size_t *shape) {
  97. Desc tmp;
  98. tmp.dev_ptr = nullptr;
  99. tmp.ndim = ndim;
  100. tmp.shape = shape;
  101. copy(tmp, CopyDirection::OTHER_TO_SELF);
  102. }
  103. //! name of dtype of this tensor
  104. const char* dtype_name() const { return dtype_name(desc.dtype); }
  105. //! name of given dtype
  106. const char* dtype_name(DataType dtype) const {
  107. return m_functable->dtype_name(dtype);
  108. }
  109. /*!
  110. * \brief forward memory from \p other directly to the underlying
  111. * storage
  112. *
  113. * This can only be used when there is a corresponding VarNode for
  114. * this DeviceTensor. (e.g. for the outputs of Craniotome oprs)
  115. */
  116. void forward_other_memory(
  117. const Desc &other, CallbackOnce deleter) const {
  118. m_functable->forward_other_memory(this, other, deleter);
  119. }
  120. /*!
  121. * \brief forward device buffer to \p dest directly and create a
  122. * tensor storage shared memory with m_dv_nd, it would be deleted
  123. * when calling deleter, so refcnt to data ptr could be managed
  124. * correctly.
  125. */
  126. void forward_to(
  127. void **dest, CallbackOnce* deleter) const {
  128. m_functable->forward_to(this, dest, deleter);
  129. }
  130. struct _Impl;
  131. private:
  132. // note: we use a func table to avoid symbol visibility problems and
  133. // linking hazards when built with other code base
  134. struct FuncTable {
  135. void (*sync)(const DeviceTensor*, bool);
  136. void (*copy)(DeviceTensor*, const Desc&, CopyDirection);
  137. void (*forward_other_memory)(const DeviceTensor*, const Desc&,
  138. CallbackOnce);
  139. const char* (*dtype_name)(DataType);
  140. void (*forward_to)(const DeviceTensor*, void**, CallbackOnce*);
  141. };
  142. bool m_readonly;
  143. void* m_dev_nd;
  144. void* m_varptr;
  145. FuncTable* m_functable;
  146. public:
  147. uint32_t _version1;
  148. };
  149. /*!
  150. * \brief reinterpret_cast raw pointer or pointer integer to mgb object and
  151. * check version
  152. * \return object pointer if the version is correct; nullptr if failed
  153. */
  154. template<typename T, typename S>
  155. T* as_versioned_obj(S &&val) {
  156. T *obj = reinterpret_cast<T*>(val);
  157. if (obj->_version0 != T::CURRENT_VERSION ||
  158. obj->_version1 != T::CURRENT_VERSION) {
  159. return nullptr;
  160. }
  161. return obj;
  162. }
  163. } // namespace pubapi
  164. } // namespace mgb
  165. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台