You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

mem_scheduler_test.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include <map>
  18. #include "common/common_test.h"
  19. #include "runtime/device/memory_scheduler.h"
  20. namespace mindspore::device {
// Total "device" memory the fake handler reports as available (1 GiB).
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
// Length of the backing byte arrays below; bounds the number of fake allocations.
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
  23. class MemHandlerImpl : public MemHandler {
  24. public:
  25. MemHandlerImpl() {
  26. device_mem_.resize(kMaxVirtualCount, 0);
  27. host_mem_.resize(kMaxVirtualCount, 1);
  28. }
  29. size_t GetAvailableMemSize() override { return kDeviceMemSize; }
  30. void *MallocDevice(size_t mem_size) override {
  31. auto ret = device_mem_.data() + device_virtual_count_;
  32. ++device_virtual_count_;
  33. device_mem_size_.emplace(ret, mem_size);
  34. return ret;
  35. }
  36. void FreeDevice(void *ptr) override {
  37. auto iter = device_mem_size_.find(ptr);
  38. if (iter != device_mem_size_.end()) {
  39. device_mem_size_.erase(iter);
  40. }
  41. }
  42. void *MallocHost(size_t mem_size) override {
  43. auto ret = host_mem_.data() + host_virtual_count_;
  44. ++host_virtual_count_;
  45. host_mem_size_.emplace(ret, mem_size);
  46. return ret;
  47. }
  48. void FreeHost(void *ptr) override {
  49. auto iter = host_mem_size_.find(ptr);
  50. if (iter != host_mem_size_.end()) {
  51. host_mem_size_.erase(iter);
  52. }
  53. }
  54. void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
  55. void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
  56. private:
  57. std::vector<uint8_t> device_mem_;
  58. std::vector<uint8_t> host_mem_;
  59. size_t device_virtual_count_;
  60. size_t host_virtual_count_;
  61. std::map<void *, size_t> device_mem_size_;
  62. std::map<void *, size_t> host_mem_size_;
  63. };
  64. class TestMemScheduler : public UT::Common {
  65. public:
  66. TestMemScheduler() {}
  67. };
  68. /// Feature: MemSchedulerManager
  69. /// Description: Test MemSchedulerManager GetOrCreateMemScheduler interface
  70. /// Expectation: Create MemScheduler
  71. TEST_F(TestMemScheduler, test_mem_scheduler_manager) {
  72. MemSchedulerManager mem_scheduler_manager;
  73. auto ret = mem_scheduler_manager.GetMemScheduler(0);
  74. ASSERT_EQ(ret, nullptr);
  75. ret = mem_scheduler_manager.GetOrCreateMemScheduler(0);
  76. ASSERT_NE(ret, nullptr);
  77. ret = mem_scheduler_manager.GetMemScheduler(0);
  78. ASSERT_NE(ret, nullptr);
  79. }
  80. /// Feature: MemScheduler
  81. /// Description: Test MemScheduler interface
  82. /// Expectation: MemScheduler GetOrMalloc return valid ptr
  83. TEST_F(TestMemScheduler, test_mem_scheduler) {
  84. MemSchedulerManager mem_scheduler_manager;
  85. auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
  86. ASSERT_NE(scheduler, nullptr);
  87. auto need_record = scheduler->need_record_event();
  88. ASSERT_EQ(need_record, true);
  89. auto optimized = scheduler->optimized();
  90. ASSERT_EQ(optimized, false);
  91. std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
  92. ASSERT_NE(mem_handler, nullptr);
  93. scheduler->SetMemHandler(mem_handler);
  94. constexpr size_t kUsedTensors = 10;
  95. constexpr size_t kTimeSlice = 7;
  96. std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
  97. std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
  98. std::vector<size_t> init_tensors = {0, 2, 4};
  99. std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
  100. void *stream = nullptr;
  101. scheduler->SetTotalStep(kTimeSlice);
  102. // record
  103. for (auto index : init_tensors) {
  104. scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
  105. }
  106. for (size_t i = 0; i < kTimeSlice; ++i) {
  107. auto &tensors = step_tensors[i];
  108. for (auto j : tensors) {
  109. scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
  110. }
  111. scheduler->PostCompute(stream);
  112. }
  113. scheduler->set_need_record_event(false);
  114. // optimize
  115. scheduler->Optimize();
  116. // run
  117. scheduler->ResetCurrentStep();
  118. for (auto index : init_tensors) {
  119. scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
  120. }
  121. for (size_t i = 0; i < kTimeSlice; ++i) {
  122. scheduler->PreCompute(stream);
  123. auto &tensors = step_tensors[i];
  124. for (auto j : tensors) {
  125. auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
  126. ASSERT_NE(addr, nullptr);
  127. }
  128. scheduler->PostCompute(stream);
  129. }
  130. }
  131. } // namespace mindspore::device