
blob_manager_impl.cpp

#include "./blob_manager_impl.h"

#include <algorithm>
#include <set>
#include <vector>

#include "megbrain/utils/arith_helper.h"

namespace mgb {
namespace imperative {
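// Snapshot a blob's device storage into host memory so that the device
// buffer can be released and later re-allocated during defragmentation.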
BlobManagerImpl::BlobData::BlobData(OwnedBlob* in_blob) {
    blob = in_blob;
    DeviceTensorStorage d_storage;
    d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage);
    h_storage = HostTensorStorage(blob->m_comp_node);
    h_storage.ensure_size(blob->m_size);
    h_storage.copy_from(d_storage, blob->m_size);
}
void BlobManagerImpl::register_blob(OwnedBlob* blob) {
    // add the blob to the comp2blobs map
    MGB_LOCK_GUARD(m_mtx);
    mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob));
}
void BlobManagerImpl::unregister_blob(OwnedBlob* blob) {
    // erase the blob from the comp2blobs map
    MGB_LOCK_GUARD(m_mtx);
    mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob));
}
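// Allocate device storage for a blob. A custom allocator, if set, takes
// precedence; otherwise allocate directly, falling back to a
// defragment-and-retry cycle on failure.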
void BlobManagerImpl::alloc_with_defrag(OwnedBlob* blob, size_t size) {
    if (m_custom_allocator) {
        blob->m_storage = m_custom_allocator(blob->m_comp_node, size);
        return;
    }
    // try to allocate directly; if that fails, defragment and retry
    if (!try_alloc_direct(blob, size)) {
        mgb_log_warn("memory allocation failed for blob; try defragmenting");
        defrag(blob->m_comp_node);
        alloc_direct(blob, size);
    }
}
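// Allocate device storage of the requested size unconditionally: no retry
// or fallback happens here.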
void BlobManagerImpl::alloc_direct(OwnedBlob* blob, size_t size) {
    mgb_assert(blob->m_comp_node.valid());
    DeviceTensorStorage storage(blob->m_comp_node);
    storage.ensure_size(size);
    blob->m_storage = storage.raw_storage();
}
void BlobManagerImpl::set_allocator(allocator_t allocator) {
    m_custom_allocator = allocator;
}
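// Defragment the memory of a single comp node:
//   1. snapshot every exclusively-owned blob to host memory and release
//      its device storage;
//   2. synchronize all comp nodes and coalesce free device memory;
//   3. re-allocate the blobs in creation order and copy the data back.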
void BlobManagerImpl::defrag(const CompNode& cn) {
    auto& blobs_set_ptr = ([&]() -> auto& {
        MGB_LOCK_GUARD(m_mtx);
        return m_comp2blobs_map[cn];
    })();
    MGB_LOCK_GUARD(blobs_set_ptr.mtx);
    std::vector<BlobData> blob_data_array;
    std::set<Blob::RawStorage> storage_set;
    auto alignment = cn.get_mem_addr_alignment();
    size_t tot_sz = 0;
    // copy to HostTensorStorage, and release
    for (auto i : blobs_set_ptr.blobs_set) {
        // skip if the blob has no m_storage
        if (!i->m_storage)
            continue;
        // skip if use_count() > 1
        if (i->m_storage.use_count() > 1)
            continue;
        // two blobs must not share the same storage
        mgb_assert(storage_set.insert(i->m_storage).second);
        tot_sz += get_aligned_power2(i->m_size, alignment);
        BlobData blob_data(i);
        blob_data_array.push_back(blob_data);
        i->m_storage.reset();
    }
    // clear all, making sure every m_storage is released
    storage_set.clear();
    // skip if there is no blob to defragment
    if (blob_data_array.empty())
        return;
    // wait for all other comp nodes to avoid moved vars being read; note that
    // ExecEnv has been paused, so no new task would be dispatched
    CompNode::sync_all();
    CompNode::try_coalesce_all_free_memory();
    // try to allocate, then immediately free, the total size in one chunk,
    // so the freed memory forms a single contiguous region
    MGB_TRY { cn.free_device(cn.alloc_device(tot_sz)); }
    MGB_CATCH(MemAllocError&, {})
    // sort blobs by creation time, which may help reduce memory fragmentation
    std::sort(
            blob_data_array.begin(), blob_data_array.end(),
            [](auto& lhs, auto& rhs) { return lhs.blob->m_id < rhs.blob->m_id; });
    // allocate new device storage for each blob and copy the data back
    for (auto& i : blob_data_array) {
        DeviceTensorStorage d_storage = DeviceTensorStorage(cn);
        d_storage.ensure_size(i.blob->m_size);
        d_storage.copy_from(i.h_storage, i.blob->m_size);
        i.blob->m_storage = d_storage.raw_storage();
    }
    // wait for the copy to finish before destructing the host values
    cn.sync();
}
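// Fallback manager installed after static destruction: every operation
// asserts, except unregister_blob, which must remain a no-op because blobs
// may still be destroyed during program teardown.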
struct BlobManagerStub : BlobManager {
    void alloc_direct(OwnedBlob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
    void alloc_with_defrag(OwnedBlob* blob, size_t size) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
    void register_blob(OwnedBlob* blob) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
    void unregister_blob(OwnedBlob* blob) {}
    void defrag(const CompNode& cn) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
    void set_allocator(allocator_t allocator) {
        mgb_assert(0, "prohibited after global variable destruction");
    }
};
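// Construct the singleton in raw storage with placement new so that, once
// the Keeper's destructor has run during static destruction, the storage is
// re-used for a stub that turns late calls into clear assertion failures.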
BlobManager* BlobManager::inst() {
    static std::aligned_union_t<0, BlobManagerImpl, BlobManagerStub> storage;
    struct Keeper {
        Keeper() { new (&storage) BlobManagerImpl(); }
        ~Keeper() {
            reinterpret_cast<BlobManager*>(&storage)->~BlobManager();
            new (&storage) BlobManagerStub();
        }
    };
    static Keeper _;
    return reinterpret_cast<BlobManager*>(&storage);
}
}  // namespace imperative
}  // namespace mgb