
megbrain_pubapi.cpp 11 kB

/**
 * \file python_module/src/cpp/megbrain_pubapi.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 */
#include "./megbrain_pubapi.h"
#include "./megbrain_pubapi_internal.h"

#include "megbrain/tensor.h"
#include "megbrain/graph/var_node.h"
#include "megbrain/comp_node_env.h"
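
// DeleteDispatcher runs a user-supplied deleter exactly once: either as a
// callback enqueued on the owning comp node, or directly if the comp node has
// already been finalized. The atomic flag guards against running it twice.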
namespace {
class DeleteDispatcher final : public mgb::CompNodeDepedentObject {
    mgb::thin_function<void()> m_deleter;
    mgb::CompNode m_comp_node;
    std::atomic<bool> done;

    std::shared_ptr<void> on_comp_node_finalize() override {
        bool _ = false;
        if (done.compare_exchange_strong(_, true)) {
            m_deleter();
        }
        return {};
    }

public:
    explicit DeleteDispatcher(mgb::thin_function<void()>&& deleter,
                              mgb::CompNode cn)
            : m_deleter(std::move(deleter)), m_comp_node(cn) {
        done.store(false);
    }

    void trigger() {
        bool _ = false;
        if (done.compare_exchange_strong(_, true)) {
            if (!is_finalized()) {
                m_comp_node.add_callback(std::move(m_deleter));
            } else {
                m_deleter();
            }
        }
    }
};
} // namespace
using namespace mgb;

pubapi::DeviceTensor::DataType mgb::dtype_mgb2pubapi(DType dtype) {
    using DevDType = pubapi::DeviceTensor::DataType;
    switch (dtype.enumv()) {
#define o(s, t)            \
    case DTypeEnum::s:     \
        return DevDType::t
        o(Float32, FLOAT32);
        o(Float16, FLOAT16);
        o(Int32, INT32);
        o(Int16, INT16);
        o(Int8, INT8);
        o(Uint8, UINT8);
#undef o
        default:
            mgb_throw(MegBrainError, "dtype %s not implemented for pubapi",
                      dtype.name());
    }
}
struct pubapi::DeviceTensor::_Impl {
    static TensorShape desc_shape_to_tensor_shape(const DeviceTensor::Desc& desc) {
        TensorShape shape;
        mgb_assert(desc.ndim && desc.ndim <= TensorShape::MAX_NDIM,
                   "invalid ndim: %zu", desc.ndim);
        shape.ndim = desc.ndim;
        for (size_t i = 0; i < desc.ndim; ++i) {
            shape[i] = desc.shape[i];
        }
        return shape;
    }

#if MGB_CUDA
    class CudaCurrentDeviceRestore {
        int m_orig_dev = -1;

    public:
        CudaCurrentDeviceRestore(CompNode cn) {
            if (cn.device_type() == CompNode::DeviceType::CUDA) {
                MGB_CUDA_CHECK(cudaGetDevice(&m_orig_dev));
            }
        }
        ~CudaCurrentDeviceRestore() {
            if (m_orig_dev != -1) {
                cudaSetDevice(m_orig_dev);
            }
        }
    };
#else
    class CudaCurrentDeviceRestore {
    public:
        CudaCurrentDeviceRestore(CompNode) {}
    };
#endif
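
    // Synchronize the comp node backing this tensor; with strong == true and a
    // CUDA comp node, additionally call cudaDeviceSynchronize() on its device.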
    static void sync(const DeviceTensor* self, bool strong) {
        CompNode cn;
        if (self->m_dev_nd) {
            cn = static_cast<DeviceTensorND*>(self->m_dev_nd)->comp_node();
        } else {
            mgb_assert(self->m_varptr);
            cn = static_cast<cg::VarNode*>(self->m_varptr)->comp_node();
        }
        CudaCurrentDeviceRestore cuda_dev_restore{cn};
        cn.sync();
#if MGB_CUDA
        if (strong && cn.device_type() == CompNode::DeviceType::CUDA) {
            cn.activate();
            MGB_CUDA_CHECK(cudaDeviceSynchronize());
        }
#endif
    }
    static const char* dtype_name(DataType dtype) {
        switch (dtype) {
#define on(c)             \
    case DataType::c:     \
        return #c
            on(FLOAT32);
            on(FLOAT16);
            on(INT32);
            on(INT16);
            on(INT8);
            on(UINT8);
#undef on
            default:
                mgb_throw(MegBrainError, "invalid pubapi dtype enum: %d",
                          static_cast<int>(dtype));
        }
    }
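
    // Copy between this tensor and an externally described buffer. For
    // OTHER_TO_SELF, a VarNode-backed destination is (re)allocated via
    // shape_alloc() and self->desc is refreshed to point at the new storage;
    // a null other.dev_ptr means "resize only". Same-type copies go through
    // peer_copy_to(), cross-device ones through copy_to_device()/copy_to_host().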
    static void copy(DeviceTensor* self, const Desc& other,
                     CopyDirection direction) {
        mgb_assert(self->desc.dtype == other.dtype, "dtype mismatch: %s vs %s",
                   self->dtype_name(), dtype_name(other.dtype));
        mgb_assert(self->m_varptr || self->m_dev_nd);
        const DeviceTensorND* dv;
        if (direction == CopyDirection::OTHER_TO_SELF) {
            mgb_assert(!self->m_readonly, "can not copy into readonly tensor");
            auto shape = desc_shape_to_tensor_shape(other);
            if (self->m_varptr) {
                auto var = static_cast<cg::VarNode*>(self->m_varptr);
                dv = &var->shape_alloc(shape).dev_tensor();
            } else {
                dv = static_cast<DeviceTensorND*>(self->m_dev_nd);
                mgb_assert(dv->shape().eq_shape(shape),
                           "copy dest tensor shape is %s, but source shape is %s",
                           dv->shape().to_string().c_str(),
                           shape.to_string().c_str());
            }
            mgb_assert(self->desc.dtype == dtype_mgb2pubapi(dv->dtype()));
            self->desc.dev_ptr = dv->raw_ptr();
            self->desc.ndim = dv->shape().ndim;
            self->desc.shape = dv->shape().shape;
            if (!other.dev_ptr) {
                // used in resize()
                return;
            }
        } else {
            mgb_assert(direction == CopyDirection::SELF_TO_OTHER);
            if (self->m_varptr) {
                dv = &static_cast<cg::VarNode*>(self->m_varptr)->dev_tensor();
            } else {
                dv = static_cast<DeviceTensorND*>(self->m_dev_nd);
            }
        }
        mgb_assert(dv->layout().is_contiguous());
        auto size = dv->layout().span().dist_byte();
        auto cn = dv->comp_node();
        CudaCurrentDeviceRestore cuda_dev_restore{cn};
        void *dst = dv->raw_ptr(), *src = other.dev_ptr;
        if (direction == CopyDirection::SELF_TO_OTHER) {
            std::swap(dst, src);
        }
#if !MGB_CUDA
        mgb_assert(other.type != Type::CUDA, "cuda disabled at compile time");
#endif
        auto&& desc = self->desc;
        if (other.type == desc.type) {
#if MGB_CUDA
            if (desc.type == Type::CUDA) {
                int dev = desc.cuda_ctx.device;
                if (dev == -1) {
                    MGB_CUDA_CHECK(cudaGetDevice(&dev));
                }
                mgb_assert(dev == other.cuda_ctx.device,
                           "DeviceTensor copy must be on the same device; "
                           "got %d vs %d", dev, other.cuda_ctx.device);
            }
#endif
            cn.peer_copy_to(cn, dst, src, size);
        } else {
            if ((desc.type == Type::CPU && other.type == Type::CUDA &&
                 direction == CopyDirection::SELF_TO_OTHER) ||
                (other.type == Type::CPU && desc.type == Type::CUDA &&
                 direction == CopyDirection::OTHER_TO_SELF)) {
                cn.copy_to_device(dst, src, size);
            } else {
                mgb_assert((desc.type == Type::CUDA && other.type == Type::CPU &&
                            direction == CopyDirection::SELF_TO_OTHER) ||
                           (other.type == Type::CUDA && desc.type == Type::CPU &&
                            direction == CopyDirection::OTHER_TO_SELF));
                cn.copy_to_host(dst, src, size);
            }
        }
    }
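
    // Take ownership of externally allocated memory described by `other` and
    // make this tensor (VarNode or DeviceTensorND) view it directly, without a
    // copy. For CPU tensors the deleter is routed through DeleteDispatcher so
    // that it runs on the comp node (or immediately after finalization).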
    static void forward_other_memory(const DeviceTensor* self,
                                     const Desc& other, CallbackOnce deleter) {
        mgb_assert(self->desc.dtype == other.dtype, "dtype mismatch: %s vs %s",
                   self->dtype_name(), dtype_name(other.dtype));
        auto deleter_wrap = [deleter]() mutable { deleter.consume(); };
        thin_function<void(void*)> deleter_dispatch;
        if (self->desc.type == Type::CPU) {
            CompNode cn{};
            if (self->m_varptr) {
                cn = static_cast<cg::VarNode*>(self->m_varptr)->comp_node();
            } else {
                cn = static_cast<DeviceTensorND*>(self->m_dev_nd)->comp_node();
            }
            deleter_dispatch = [d = new DeleteDispatcher(deleter_wrap, cn)](void*) {
                d->trigger();
                delete d;
            };
        } else {
            deleter_dispatch = [deleter_wrap](void*) mutable { deleter_wrap(); };
        }
        auto shape = desc_shape_to_tensor_shape(other);
        if (self->m_varptr) {
            auto var = static_cast<cg::VarNode*>(self->m_varptr);
            DeviceTensorStorage storage;
            storage.reset(var->comp_node(),
                          shape.total_nr_elems() * var->dtype().size(),
                          {static_cast<dt_byte*>(other.dev_ptr), deleter_dispatch});
            DeviceTensorND tensor;
            tensor.reset(storage, {shape, var->dtype()});
            var->reset_dev_tensor_from_tensor(tensor);
        } else {
            DeviceTensorND& tensor = *static_cast<DeviceTensorND*>(self->m_dev_nd);
            DeviceTensorStorage storage;
            size_t dtype_size = tensor.layout().dtype.size();
            storage.reset(tensor.comp_node(),
                          shape.total_nr_elems() * dtype_size,
                          {static_cast<dt_byte*>(other.dev_ptr), deleter_dispatch});
            tensor.reset(storage, {shape, tensor.layout().dtype});
        }
    }
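
    // Expose the underlying storage pointer to the caller, together with a
    // deleter that keeps a DeviceTensorStorage reference alive until the
    // caller releases it.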
    static void forward_to(const DeviceTensor* self, void** dest,
                           CallbackOnce* deleter) {
        auto orig_dv_ptr = static_cast<DeviceTensorStorage*>(self->m_dev_nd);
        *dest = orig_dv_ptr->ptr();
        mgb_assert(*dest == self->desc.dev_ptr);
        deleter->user_data = new DeviceTensorStorage(*orig_dv_ptr);
        deleter->fptr = [](void* ptr) {
            delete reinterpret_cast<DeviceTensorStorage*>(ptr);
        };
    }
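
    // Fill a pubapi::DeviceTensor from either a DeviceTensorND or a VarNode
    // (exactly one must be non-null): install the shared function table and
    // record version, device type, dtype, shape and, for CUDA, the
    // device/stream of the comp node.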
    static void init_tensor(pubapi::DeviceTensor& dest, DeviceTensorND* tensor,
                            VarNode* var, bool readonly) {
        memset(&dest, 0, sizeof(pubapi::DeviceTensor));
        {
            static FuncTable functable{&sync, &copy, &forward_other_memory,
                                       &dtype_name, &forward_to};
            dest.m_functable = &functable;
        }
        dest._version0 = dest._version1 = CURRENT_VERSION;
        mgb_assert((!!tensor) ^ (!!var));
        auto cn = tensor ? tensor->comp_node() : var->comp_node();
        using Type = pubapi::DeviceTensor::Type;
        switch (cn.device_type()) {
            case CompNode::DeviceType::CPU:
                dest.desc.type = Type::CPU;
                break;
#if MGB_CUDA
            case CompNode::DeviceType::CUDA:
                dest.desc.type = Type::CUDA;
                break;
#endif
            default:
                mgb_throw(MegBrainError, "bad comp node type: %d",
                          static_cast<int>(cn.device_type()));
        }
        dest.desc.dtype = dtype_mgb2pubapi(tensor ? tensor->dtype() : var->dtype());
        if (tensor) {
            dest.desc.dev_ptr = tensor->raw_ptr();
            dest.desc.shape = tensor->shape().shape;
            dest.desc.ndim = tensor->shape().ndim;
            dest.size_bytes = tensor->layout().span().dist_byte();
        }
#if MGB_CUDA
        if (dest.desc.type == Type::CUDA) {
            auto&& env = CompNodeEnv::from_comp_node(cn).cuda_env();
            dest.desc.cuda_ctx.device = env.device;
            dest.desc.cuda_ctx.stream = env.stream;
        }
#endif
        dest.m_readonly = readonly;
        dest.m_dev_nd = tensor;
        dest.m_varptr = var;
    }
}; // pubapi::DeviceTensor::_Impl

void mgb::init_pubapi_dev_tensor(pubapi::DeviceTensor& dest,
                                 DeviceTensorND* tensor, VarNode* var,
                                 bool readonly) {
    pubapi::DeviceTensor::_Impl::init_tensor(dest, tensor, var, readonly);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.
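
For orientation, here is a minimal usage sketch (not part of the file above) of handing an existing DeviceTensorND to external consumers via mgb::init_pubapi_dev_tensor, whose definition appears at the end of the listing. The CompNode::load("xpu0") call and the (comp node, shape, dtype) DeviceTensorND constructor are assumed from the regular MegBrain API; the helper name export_tensor_example is made up for illustration.

#include "megbrain/tensor.h"

#include "./megbrain_pubapi.h"
#include "./megbrain_pubapi_internal.h"

// Sketch: wrap an existing DeviceTensorND so external code can read its
// descriptor; readonly = true forbids copies back into the tensor (see copy()).
void export_tensor_example() {
    auto cn = mgb::CompNode::load("xpu0");
    // assumed: the usual (comp node, shape, dtype) DeviceTensorND constructor
    mgb::DeviceTensorND dv{cn, mgb::TensorShape{2, 3}, mgb::dtype::Float32()};

    mgb::pubapi::DeviceTensor exported;
    mgb::init_pubapi_dev_tensor(exported, &dv, /*var=*/nullptr, /*readonly=*/true);

    // init_tensor() has filled in dtype, dev_ptr, shape/ndim, size_bytes and
    // the device type derived from the comp node.
    mgb_assert(exported.desc.ndim == 2);
}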