/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "common/common_test.h" #include "runtime/device/memory_scheduler.h" namespace mindspore::device { constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024; constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024; class MemHandlerImpl : public MemHandler { public: MemHandlerImpl() { device_mem_.resize(kMaxVirtualCount, 0); host_mem_.resize(kMaxVirtualCount, 1); } size_t GetAvailableMemSize() override { return kDeviceMemSize; } void *MallocDevice(size_t mem_size) override { auto ret = device_mem_.data() + device_virtual_count_; ++device_virtual_count_; device_mem_size_.emplace(ret, mem_size); return ret; } void FreeDevice(void *ptr) override { auto iter = device_mem_size_.find(ptr); if (iter != device_mem_size_.end()) { device_mem_size_.erase(iter); } } void *MallocHost(size_t mem_size) override { auto ret = host_mem_.data() + host_virtual_count_; ++host_virtual_count_; host_mem_size_.emplace(ret, mem_size); return ret; } void FreeHost(void *ptr) override { auto iter = host_mem_size_.find(ptr); if (iter != host_mem_size_.end()) { host_mem_size_.erase(iter); } } void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {} void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {} private: std::vector device_mem_; std::vector host_mem_; size_t device_virtual_count_; size_t host_virtual_count_; std::map device_mem_size_; std::map host_mem_size_; }; class TestMemScheduler : public UT::Common { public: TestMemScheduler() {} }; /// Feature: MemSchedulerManager /// Description: Test MemSchedulerManager GetOrCreateMemScheduler interface /// Expectation: Create MemScheduler TEST_F(TestMemScheduler, test_mem_scheduler_manager) { MemSchedulerManager mem_scheduler_manager; auto ret = mem_scheduler_manager.GetMemScheduler(0); ASSERT_EQ(ret, nullptr); ret = mem_scheduler_manager.GetOrCreateMemScheduler(0); ASSERT_NE(ret, nullptr); ret = mem_scheduler_manager.GetMemScheduler(0); ASSERT_NE(ret, nullptr); } /// Feature: MemScheduler /// Description: Test MemScheduler interface /// Expectation: MemScheduler GetOrMalloc return valid ptr TEST_F(TestMemScheduler, test_mem_scheduler) { MemSchedulerManager mem_scheduler_manager; auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0); ASSERT_NE(scheduler, nullptr); auto need_record = scheduler->need_record_event(); ASSERT_EQ(need_record, true); auto optimized = scheduler->optimized(); ASSERT_EQ(optimized, false); std::shared_ptr mem_handler = std::make_shared(); ASSERT_NE(mem_handler, nullptr); scheduler->SetMemHandler(mem_handler); constexpr size_t kUsedTensors = 10; constexpr size_t kTimeSlice = 7; std::vector tensor_keys(kUsedTensors, 0); std::vector tensor_datas(kUsedTensors, 0); std::vector init_tensors = {0, 2, 4}; std::vector> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}}; void *stream = nullptr; scheduler->SetTotalStep(kTimeSlice); // record for (auto index : init_tensors) { scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh); } for (size_t i = 0; i < kTimeSlice; ++i) { auto &tensors = step_tensors[i]; for (auto j : tensors) { scheduler->GetOrMalloc(tensor_keys.data() + j, 1); } scheduler->PostCompute(stream); } scheduler->set_need_record_event(false); // optimize scheduler->Optimize(); // run scheduler->ResetCurrentStep(); for (auto index : init_tensors) { scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh); } for (size_t i = 0; i < kTimeSlice; ++i) { scheduler->PreCompute(stream); auto &tensors = step_tensors[i]; for (auto j : tensors) { auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1); ASSERT_NE(addr, nullptr); } scheduler->PostCompute(stream); } } } // namespace mindspore::device