You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

blob_manager_impl.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
#include "./blob_manager_impl.h"

#include <algorithm>
#include <set>
#include <utility>

#include "megbrain/utils/arith_helper.h"
  4. namespace mgb {
  5. namespace imperative {
  6. BlobManagerImpl::BlobData::BlobData(OwnedBlob* in_blob) {
  7. blob = in_blob;
  8. DeviceTensorStorage d_storage;
  9. d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage);
  10. h_storage = HostTensorStorage(blob->m_comp_node);
  11. h_storage.ensure_size(blob->m_size);
  12. h_storage.copy_from(const_cast<DeviceTensorStorage&>(d_storage), blob->m_size);
  13. }
  14. void BlobManagerImpl::register_blob(OwnedBlob* blob) {
  15. // add blob into the comp2blobs map
  16. MGB_LOCK_GUARD(m_mtx);
  17. mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob));
  18. }
  19. void BlobManagerImpl::unregister_blob(OwnedBlob* blob) {
  20. // erase blob into the comp2blobs map
  21. MGB_LOCK_GUARD(m_mtx);
  22. mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob));
  23. }
// Allocate device storage for a blob, falling back to a defragment-and-retry
// on out-of-memory. A second allocation failure propagates to the caller.
void BlobManagerImpl::alloc_with_defrag(OwnedBlob* blob, size_t size) {
    // A user-provided allocator bypasses the defrag fallback entirely.
    if (custom_allocator) {
        blob->m_storage = custom_allocator(blob->m_comp_node, size);
        return;
    }
    // try alloc
    MGB_TRY { alloc_direct(blob, size); }
    // if fail, try defrag, alloc again
    MGB_CATCH(MemAllocError&, {
        mgb_log_warn("memory allocation failed for blob; try defragmenting");
        defrag(blob->m_comp_node);
        alloc_direct(blob, size);
    });
}
  38. void BlobManagerImpl::alloc_direct(OwnedBlob* blob, size_t size) {
  39. DeviceTensorStorage storage(blob->m_comp_node);
  40. mgb_assert(blob->m_comp_node.valid());
  41. storage.ensure_size(size);
  42. blob->m_storage = storage.raw_storage();
  43. }
  44. DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(
  45. CompNode cn, TensorLayout& layout) {
  46. DeviceTensorND dev_tensor;
  47. if (custom_allocator) {
  48. DeviceTensorStorage storage(cn);
  49. size_t sz = layout.dtype.size(layout.total_nr_elems());
  50. storage.reset(cn, sz, custom_allocator(cn, sz));
  51. dev_tensor.reset(storage, layout);
  52. return dev_tensor;
  53. }
  54. MGB_TRY { dev_tensor = alloc_workspace(cn, layout); }
  55. MGB_CATCH(MemAllocError&, {
  56. mgb_log_warn("memory allocation failed for workspace; try defragmenting");
  57. defrag(cn);
  58. dev_tensor = alloc_workspace(cn, layout);
  59. });
  60. return dev_tensor;
  61. };
  62. DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout) {
  63. DeviceTensorStorage storage(cn);
  64. storage.ensure_size(layout.dtype.size(layout.total_nr_elems()));
  65. DeviceTensorND dev_tensor;
  66. dev_tensor.reset(storage, layout);
  67. return dev_tensor;
  68. }
  69. void BlobManagerImpl::set_allocator(allocator_t allocator) {
  70. custom_allocator = allocator;
  71. }
  72. void BlobManagerImpl::defrag(const CompNode& cn) {
  73. BlobSetWithMux* blobs_set_ptr;
  74. {
  75. MGB_LOCK_GUARD(m_mtx);
  76. blobs_set_ptr = &m_comp2blobs_map[cn];
  77. }
  78. MGB_LOCK_GUARD(blobs_set_ptr->mtx);
  79. std::vector<BlobData> blob_data_arrary;
  80. std::set<Blob::RawStorage> storage_set;
  81. auto alignment = cn.get_mem_addr_alignment();
  82. size_t tot_sz = 0;
  83. // copy to HostTensorStorage, and release
  84. for (auto i : blobs_set_ptr->blobs_set) {
  85. // skip if blob do not have m_storage
  86. if (!i->m_storage)
  87. continue;
  88. // skip if ues_count() > 1
  89. if (i->m_storage.use_count() > 1)
  90. continue;
  91. // two blobs can't share same storage
  92. mgb_assert(storage_set.insert(i->m_storage).second);
  93. tot_sz += get_aligned_power2(i->m_size, alignment);
  94. BlobData blob_data(i);
  95. blob_data_arrary.push_back(blob_data);
  96. i->m_storage.reset();
  97. }
  98. // clear all, make sure m_storage will be release
  99. storage_set.clear();
  100. // skip if no blob to defrag
  101. if (!blob_data_arrary.size())
  102. return;
  103. // wait all other comp nodes to avoid moved var being read; note that
  104. // ExecEnv has been paused, so no new task would not be dispatched
  105. CompNode::sync_all();
  106. CompNode::try_coalesce_all_free_memory();
  107. // try free all
  108. MGB_TRY { cn.free_device(cn.alloc_device(tot_sz)); }
  109. MGB_CATCH(MemAllocError&, {})
  110. // sort blobs by created time, may be helpful for reduce memory fragment
  111. std::sort(
  112. blob_data_arrary.begin(), blob_data_arrary.end(),
  113. [](auto& lhs, auto& rhs) { return lhs.blob->m_id < rhs.blob->m_id; });
  114. // allocate for each storage
  115. for (auto i : blob_data_arrary) {
  116. DeviceTensorStorage d_storage = DeviceTensorStorage(cn);
  117. d_storage.ensure_size(i.blob->m_size);
  118. d_storage.copy_from(i.h_storage, i.blob->m_size);
  119. i.blob->m_storage = d_storage.raw_storage();
  120. }
  121. // wait copy finish before destructing host values
  122. cn.sync();
  123. }
// Fallback manager installed over the BlobManagerImpl's storage after it is
// destroyed during global destruction (see BlobManager::inst()). Every
// operation that would need a live manager aborts; the one deliberate
// exception is unregister_blob, documented below.
struct BlobManagerStub : BlobManager {
    void alloc_direct(OwnedBlob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
    void alloc_with_defrag(OwnedBlob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
    DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
    void register_blob(OwnedBlob* blob) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
    // Intentionally a no-op: blobs destroyed after the real manager is gone
    // must still be able to unregister without crashing.
    void unregister_blob(OwnedBlob* blob){};
    void defrag(const CompNode& cn) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
    virtual void set_allocator(allocator_t allocator) {
        mgb_assert(0, "prohibited after global variable destruction");
    };
};
// Singleton accessor. Uses placement-new into a static buffer so that when
// the real manager is destroyed at shutdown, a stub can be constructed in
// the very same storage: late callers during global destruction then hit
// the stub's defined behavior instead of a use-after-destroy.
BlobManager* BlobManager::inst() {
    // Buffer large and aligned enough for either implementation.
    static std::aligned_union_t<0, BlobManagerImpl, BlobManagerStub> storage;
    struct Keeper {
        Keeper() { new (&storage) BlobManagerImpl(); }
        ~Keeper() {
            // Destroy the real manager via the (presumably virtual — defined
            // in the unseen base class) destructor, then install the stub.
            reinterpret_cast<BlobManager*>(&storage)->~BlobManager();
            new (&storage) BlobManagerStub();
        }
    };
    // The Keeper's lifetime drives both construction and the swap-to-stub.
    static Keeper _;
    return reinterpret_cast<BlobManager*>(&storage);
}
  157. } // namespace imperative
  158. } // namespace mgb