Merge pull request !3264 from JoyLvliang/enable-mem-pool-manage-pynative-and-graph-static-memtags/v0.6.0-beta
| @@ -618,7 +618,12 @@ AscendDeviceAddress::~AscendDeviceAddress() { | |||
| return; | |||
| } | |||
| if (from_mem_pool_) { | |||
| AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); | |||
| if (communication_ptr_ != nullptr) { | |||
| AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_); | |||
| communication_ptr_ = nullptr; | |||
| } else { | |||
| AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); | |||
| } | |||
| ptr_ = nullptr; | |||
| } | |||
| } | |||
| @@ -21,32 +21,23 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| constexpr uint64_t kAscendDeviceMemGB = 26; | |||
| constexpr uint64_t kAscendMemPoolGB = 4; | |||
| constexpr uint64_t kAscendDeviceMemGB = 30; | |||
| constexpr uint64_t kMemSizeGB = 30; | |||
| constexpr uint64_t kMaxMemSizeGB = 30; | |||
| constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); | |||
| constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB); | |||
| constexpr uint64_t kReservedMemorySize = 10 * 1024 * 1024; | |||
| void AscendMemoryManager::MallocDeviceMemory() { | |||
| auto context_mem = GetDeviceMemSizeFromContext(); | |||
| device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem; | |||
| static_mem_offset_ = device_mem_size_; | |||
| auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM); | |||
| auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]"; | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; | |||
| } | |||
| if (context_mem == 0) { | |||
| device_mem_pool_size_ = kAscendMemPoolSize; | |||
| ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; | |||
| } | |||
| AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_); | |||
| AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_); | |||
| } | |||
| dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize; | |||
| AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_); | |||
| AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); | |||
| } | |||
| uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | |||
| @@ -64,7 +55,7 @@ uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | |||
| auto gb_str = variable_memory_max_size.substr(0, pos); | |||
| auto gb_var = std::stoull(gb_str); | |||
| MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; | |||
| if (gb_var > kMaxMemSizeGB || gb_var == 0) { | |||
| if (gb_var > kAscendDeviceMemGB || gb_var == 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; | |||
| } | |||
| return gb_var << kMemSizeGB; | |||
| @@ -87,8 +78,60 @@ void AscendMemoryManager::FreeDeviceMemory() { | |||
| } | |||
| } | |||
| void AscendMemoryManager::ResetDynamicMemory() { | |||
| total_dynamic_size_ = 0; | |||
| dynamic_mem_offset_ = device_mem_size_ - kReservedMemorySize; | |||
| AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); | |||
| } | |||
| void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { | |||
| return AscendMemoryPool::GetInstance().AllocTensorMem(size); | |||
| auto align_size = GetCommonAlignSize(size); | |||
| return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); | |||
| } | |||
| uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) { | |||
| size_t align_size = 0; | |||
| if (communication_mem) { | |||
| align_size = GetCommunicationAlignSize(size); | |||
| } else { | |||
| align_size = GetCommonAlignSize(size); | |||
| } | |||
| if (communication_mem) { | |||
| // create protect area [kMemAlignSize -- data -- kMemAlignSize] | |||
| uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); | |||
| return alloc_address + kMemAlignSize; | |||
| } else { | |||
| return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); | |||
| } | |||
| } | |||
| uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { | |||
| size_t align_size = 0; | |||
| if (communication_mem) { | |||
| align_size = GetCommunicationAlignSize(size); | |||
| } else { | |||
| align_size = GetCommonAlignSize(size); | |||
| } | |||
| if (dynamic_mem_offset_ < align_size) { | |||
| MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ | |||
| << "]) malloc [" << align_size << "] failed!"; | |||
| } | |||
| auto new_offset = dynamic_mem_offset_ - align_size; | |||
| auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); | |||
| if (new_offset <= device_mem_pool_offset) { | |||
| MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ | |||
| << "] memory pool[" << device_mem_pool_offset << "])" | |||
| << " malloc [" << align_size << "] failed!"; | |||
| } | |||
| total_dynamic_size_ += align_size; | |||
| dynamic_mem_offset_ = new_offset; | |||
| AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); | |||
| if (communication_mem) { | |||
| // create protect area [kMemAlignSize -- data -- kMemAlignSize] | |||
| return device_mem_base_ + dynamic_mem_offset_ + kMemAlignSize; | |||
| } else { | |||
| return device_mem_base_ + dynamic_mem_offset_; | |||
| } | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -27,8 +27,13 @@ class AscendMemoryManager : public MemoryManager { | |||
| void MallocDeviceMemory() override; | |||
| void FreeDeviceMemory() override; | |||
| void ResetDynamicMemory() override; | |||
| void *MallocMemFromMemPool(size_t size) override; | |||
| protected: | |||
| uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; | |||
| uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override; | |||
| private: | |||
| uint8_t *device_mem_pool_base_{nullptr}; | |||
| uint64_t device_mem_pool_size_{0}; | |||
| @@ -22,51 +22,56 @@ namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | |||
| if (has_malloc_) { | |||
| MS_LOG(EXCEPTION) << "Memory pool has been allocated memory resource!"; | |||
| if (size == 0) { | |||
| MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; | |||
| } | |||
| if (size == 0 || size > free_mem_size_) { | |||
| MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero or large than free mem size!"; | |||
| if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) { | |||
| MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ [" | |||
| << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_ | |||
| << "], need memory size [" << size << "]"; | |||
| } | |||
| *addr = device_mem_pool_base_; | |||
| *addr = device_mem_pool_base_ + device_mem_pool_offset_; | |||
| device_mem_pool_offset_ += size; | |||
| if (*addr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Device memory pool base address is nullptr, failed to alloc memory pool resource!"; | |||
| MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!"; | |||
| } | |||
| has_malloc_ = true; | |||
| free_mem_size_ -= size; | |||
| return size; | |||
| } | |||
| bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| has_malloc_ = false; | |||
| free_mem_size_ = total_mem_size_; | |||
| return true; | |||
| } | |||
| size_t AscendMemoryPool::AlignMemorySize(size_t size) const { | |||
| if (size == 0) { | |||
| return DYNAMIC_MEM_ALIGN_SIZE; | |||
| MS_LOG(EXCEPTION) << "The align memory size is a zero !"; | |||
| } | |||
| return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; | |||
| return size; | |||
| } | |||
| size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - DYNAMIC_MEM_ALIGN_SIZE; } | |||
| size_t AscendMemoryPool::mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE / 2; } | |||
| void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { | |||
| MS_EXCEPTION_IF_NULL(device_mem_pool_base); | |||
| device_mem_pool_base_ = device_mem_pool_base; | |||
| } | |||
| void AscendMemoryPool::set_device_mem_pool_size(uint64_t device_mem_pool_size) { | |||
| device_mem_pool_size_ = device_mem_pool_size; | |||
| free_mem_size_ = device_mem_pool_size_; | |||
| total_mem_size_ = free_mem_size_; | |||
| void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) { | |||
| graph_dynamic_mem_offset_ = graph_dynamic_mem_offset; | |||
| } | |||
| size_t AscendMemoryPool::free_mem_size() { return free_mem_size_; } | |||
| uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; } | |||
| size_t AscendMemoryPool::free_mem_size() { | |||
| if (graph_dynamic_mem_offset_ <= device_mem_pool_offset_) { | |||
| MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_ | |||
| << "] less than or equal to device mem pool offset [" << device_mem_pool_offset_ << "]!"; | |||
| } | |||
| return graph_dynamic_mem_offset_ - device_mem_pool_offset_; | |||
| } | |||
| size_t AscendMemoryPool::total_mem_size() { return total_mem_size_; } | |||
| size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -32,8 +32,9 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | |||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | |||
| void set_device_mem_pool_base(uint8_t *device_mem_pool_base); | |||
| void set_device_mem_pool_size(uint64_t device_mem_pool_size); | |||
| void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset); | |||
| uint64_t device_mem_pool_offset() const; | |||
| size_t free_mem_size() override; | |||
| size_t total_mem_size() override; | |||
| @@ -50,11 +51,9 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||
| private: | |||
| AscendMemoryPool() = default; | |||
| bool has_malloc_{false}; | |||
| uint8_t *device_mem_pool_base_{nullptr}; | |||
| uint64_t device_mem_pool_size_{0}; | |||
| size_t free_mem_size_{0}; | |||
| size_t total_mem_size_{0}; | |||
| uint64_t device_mem_pool_offset_{0}; | |||
| uint64_t graph_dynamic_mem_offset_{0}; | |||
| }; | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -76,6 +76,7 @@ class DeviceAddress : public mindspore::DeviceSync { | |||
| string format_{"DefaultFormat"}; | |||
| TypeId type_id_{kNumberTypeFloat16}; | |||
| bool from_mem_pool_{false}; | |||
| uint8_t *communication_ptr_{nullptr}; | |||
| std::vector<int> host_shape_{}; | |||
| friend class KernelRuntime; | |||
| friend class MemoryManager; | |||
| @@ -335,8 +335,10 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| output_type_id = AnfAlgo::GetOutputInferDataType(item, index); | |||
| } | |||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||
| auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); | |||
| auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | |||
| } | |||
| AnfAlgo::SetOutputAddr(address, index, item.get()); | |||
| } | |||
| } | |||
| @@ -434,11 +436,18 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode | |||
| // reuse communication op's all outputs' memory | |||
| type = kReuseDynamicCommMem; | |||
| } | |||
| uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size); | |||
| uint8_t *output_ptr = nullptr; | |||
| for (size_t j = 0; j < align_size_list.size(); ++j) { | |||
| std::string output_format = AnfAlgo::GetOutputFormat(node, j); | |||
| auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); | |||
| auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); | |||
| auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (output_ptr == nullptr) { | |||
| output_ptr = mem_manager_->MallocMem(address, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0)); | |||
| MS_EXCEPTION_IF_NULL(output_ptr); | |||
| } else { | |||
| address->set_ptr(output_ptr); | |||
| } | |||
| AnfAlgo::SetOutputAddr(address, j, node.get()); | |||
| output_ptr += align_size_list[j]; | |||
| } | |||
| @@ -464,7 +473,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| size_t total_size = 0; | |||
| std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size; | |||
| std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size; | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { | |||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); | |||
| auto input_node = input_node_with_index.first; | |||
| @@ -477,9 +486,13 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); | |||
| total_size += mem_size; | |||
| addr_size.emplace_back(address.get(), mem_size); | |||
| addr_size.emplace_back(address, mem_size); | |||
| } | |||
| uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size); | |||
| if (addr_size.empty()) { | |||
| return; | |||
| } | |||
| uint8_t *input_ptr = | |||
| mem_manager_->MallocMem(addr_size[0].first, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0)); | |||
| for (const auto &iter : addr_size) { | |||
| MS_EXCEPTION_IF_NULL(iter.first); | |||
| iter.first->set_ptr(input_ptr); | |||
| @@ -509,15 +522,13 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in | |||
| MS_LOG(INFO) << "Already malloc index:" << i; | |||
| continue; | |||
| } | |||
| auto ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i]); | |||
| if (ptr == nullptr) { | |||
| // reused ptr, no need alloc, continue; | |||
| continue; | |||
| } | |||
| std::string output_format = AnfAlgo::GetOutputFormat(node, i); | |||
| auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); | |||
| auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type); | |||
| auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| uint8_t *ptr = | |||
| mem_manager_->MallocMem(device_address, type, output_sizes[i], std::pair<AnfNodePtr, size_t>(node, i)); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); | |||
| AnfAlgo::SetOutputAddr(device_address, i, node.get()); | |||
| } | |||
| @@ -543,16 +554,12 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||
| } | |||
| auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); | |||
| DeviceAddressPtr address = nullptr; | |||
| if (ms_context->enable_pynative_infer()) { | |||
| address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (!mem_manager_->MallocMemFromMemPool(address, node_size)) { | |||
| MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; | |||
| } | |||
| } else { | |||
| auto ptr = mem_manager_->MallocMem(kStaticMem, node_size); | |||
| address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; | |||
| } else if (mem_manager_->MallocMem(address, kStaticMem, node_size) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; | |||
| } | |||
| AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); | |||
| if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), | |||
| @@ -582,16 +589,12 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||
| auto value = GetValue<std::string>(node_value); | |||
| size_t tensor_size = value.size(); | |||
| DeviceAddressPtr address = nullptr; | |||
| if (ms_context->enable_pynative_infer()) { | |||
| address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) { | |||
| MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; | |||
| } | |||
| } else { | |||
| auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); | |||
| address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; | |||
| } else if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | |||
| } | |||
| AnfAlgo::SetOutputAddr(address, 0, value_node.get()); | |||
| std::vector<int> shape = {1, SizeToInt(tensor_size)}; | |||
| @@ -95,6 +95,31 @@ uint8_t *MemoryManager::MallocMem(MemType type, size_t size) { | |||
| return ptr; | |||
| } | |||
| uint8_t *MemoryManager::MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size, | |||
| const session::KernelWithIndex &node_with_index) { | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| uint8_t *ptr = nullptr; | |||
| if (node_with_index.first != nullptr) { | |||
| ptr = MallocOutputMem(node_with_index.first, node_with_index.second, flag, size); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (AnfAlgo::IsCommunicationOp(node_with_index.first) && context_ptr->enable_hccl()) { | |||
| address->communication_ptr_ = ptr - kMemAlignSize; | |||
| } | |||
| } else { | |||
| ptr = MallocMem(flag, size); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| } | |||
| address->ptr_ = ptr; | |||
| if (flag == kStaticMem) { | |||
| address->from_mem_pool_ = true; | |||
| } | |||
| return ptr; | |||
| } | |||
| uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) { | |||
| size_t align_size = 0; | |||
| if (communication_mem) { | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "backend/optimizer/mem_reuse/mem_reuse.h" | |||
| #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" | |||
| namespace mindspore { | |||
| @@ -34,7 +35,7 @@ class MemoryManager { | |||
| virtual void MallocDeviceMemory() = 0; | |||
| virtual void FreeDeviceMemory() = 0; | |||
| void ResetDynamicMemory() { | |||
| virtual void ResetDynamicMemory() { | |||
| total_dynamic_size_ = 0; | |||
| dynamic_mem_offset_ = 0; | |||
| } | |||
| @@ -42,6 +43,8 @@ class MemoryManager { | |||
| void MallocReusedDynamicMem(const session::KernelGraph *graph); | |||
| uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); | |||
| uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); | |||
| uint8_t *MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size, | |||
| const session::KernelWithIndex &node_with_index = std::pair<AnfNodePtr, size_t>(nullptr, 0)); | |||
| virtual uint8_t *MallocMem(MemType type, size_t size); | |||
| virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); | |||