From: @laiyongqiang Reviewed-by: @jjfeing, @majorzhang Signed-off-by: @majorzhang (pull/15161/MERGE)
@@ -109,16 +109,16 @@ class DynamicMemPoolBestFit {
  protected:
   // The real size by memory alloc aligned.
   virtual size_t AlignMemorySize(size_t size) const;
+  // Get the minimum memory unit size used for dynamic extension.
+  virtual size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
+  // Calculate the required alloc size when adding a memory block.
+  virtual size_t CalMemBlockAllocSize(size_t size);
  private:
-  // Get the minimum memory unit size using for dynamic extend.
-  size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
   // Find the idle memory buf by aligned size when memory alloc.
   DeviceMemPtr FindIdleMemBuf(size_t size);
   // Add the memory block and memory buf when memory alloc not find the idle memory buf.
   DeviceMemPtr AddMemBlockAndMemBuf(size_t size);
-  // Calculate memory block required alloc size when adding the memory block.
-  size_t CalMemBlockAllocSize(size_t size);
   // Judge whether need divide the memory buf by alloc size and memory buf size.
   bool IsDivide(size_t tensor_size, size_t mem_buf_size) const;
   // Divide the memory buf by alloc size.
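The point of this hunk is to turn block growth into a backend extension point: `mem_alloc_unit_size()` and `CalMemBlockAllocSize()` move from `private` to `protected` and become `virtual`, so a device-specific pool can replace the growth policy without touching the best-fit logic. A self-contained sketch of the pattern (stand-in names, not the real MindSpore classes):

```cpp
#include <cstddef>

// Stand-in base: the protected virtuals are the extension point.
class PoolBaseSketch {
 public:
  virtual ~PoolBaseSketch() = default;

 protected:
  virtual size_t mem_alloc_unit_size() const { return 8u << 20; }  // 8 MiB default unit
  virtual size_t CalMemBlockAllocSize(size_t size) {
    size_t block = mem_alloc_unit_size();
    while (block < size) block += mem_alloc_unit_size();  // default: grow linearly
    return block;
  }
};

// Stand-in backend: overrides only the growth policy.
class BackendPoolSketch : public PoolBaseSketch {
 protected:
  size_t mem_alloc_unit_size() const override { return 256u << 20; }  // 256 MiB unit
  size_t CalMemBlockAllocSize(size_t size) override {
    size_t block = mem_alloc_unit_size();
    while (block < size) block *= 2;  // backend-specific: grow by doubling
    return block;
  }
};
```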
@@ -447,6 +447,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   auto output_sizes = kernel_mod->GetOutputSizeList();
+  auto index = 0;
   for (const auto &size : output_sizes) {
     auto output_tensor_index = tensor_index;
     tensor_index++;
@@ -455,15 +456,21 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
     tensor->lifetime_.start_ = node->GetId();
     tensor->lifetime_.end_ = node->GetId();
     tensor->type_ = kOutputOnly;
+    if (AnfAlgo::OutputAddrExist(kernel, index)) {
+      tensor->aligned_size_ = 0;
+    }
     tensors_list_.push_back(tensor);
     tensors_map_[output_tensor_index] = tensor;
     stream->tensors_.push_back(tensor);
     node->tensors_.insert(tensor);
     node->output_tensors_.push_back(tensor);
+    index++;
   }
   // WorkSpace Tensor
   auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
+  index = 0;
   for (const auto &size : workspace_sizes) {
     auto workspace_tensor_index = tensor_index;
     tensor_index++;
@@ -471,11 +478,15 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
     tensor->type_ = kWorkspace;
     tensor->lifetime_.start_ = node->GetId();
     tensor->lifetime_.end_ = node->GetId();
+    if (AnfAlgo::WorkspaceAddrExist(kernel, index)) {
+      tensor->aligned_size_ = 0;
+    }
     tensors_list_.push_back(tensor);
     tensors_map_[workspace_tensor_index] = tensor;
     stream->tensors_.push_back(tensor);
     node->tensors_.insert(tensor);
     node->workspace_tensors_.push_back(tensor);
+    index++;
   }
   }
 }
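One small nit on the new index bookkeeping, offered as a suggestion rather than as part of the patch: `auto index = 0;` deduces `int`, while the index presumably flows into unsigned index parameters such as the one on `OutputAddrExist`, so an explicit `size_t` would avoid sign-conversion warnings. Sketch of the suggested shape:

```cpp
#include <cstddef>
#include <vector>

// Suggested counter shape (sketch): an explicit size_t instead of
// `auto index = 0;`, matching the unsigned indices used elsewhere.
void Walk(const std::vector<size_t> &output_sizes) {
  size_t index = 0;
  for (const auto &size : output_sizes) {
    (void)size;  // per-output work would go here
    ++index;
  }
}
```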
@@ -874,8 +885,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {
     // Contiguous input
     if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) {
-      node->input_tensors_[0]->aligned_size_ += kGapSize;
-      node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
+      if (node->input_tensors_[0]->aligned_size_) {
+        node->input_tensors_[0]->aligned_size_ += kGapSize;
+      }
+      if (node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_) {
+        node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
+      }
       std::vector<size_t> inputs;
       for (const auto &input_tensor : node->input_tensors_) {
         comm_input_total_size_ += input_tensor->aligned_size_;
@@ -887,8 +902,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {
     // Contiguous output
     if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) {
-      node->output_tensors_[0]->aligned_size_ += kGapSize;
-      node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
+      if (node->output_tensors_[0]->aligned_size_) {
+        node->output_tensors_[0]->aligned_size_ += kGapSize;
+      }
+      if (node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_) {
+        node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
+      }
       std::vector<size_t> outputs;
       for (const auto &output_tensor : node->output_tensors_) {
         comm_output_total_size_ += output_tensor->aligned_size_;
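The four new guards repeat the same check: only pad tensors that SOMAS actually plans, i.e. those with a nonzero aligned size. If the team wants to DRY this up, a tiny helper reads well; a self-contained sketch with stand-in types (the real SomasTensor and kGapSize value differ):

```cpp
#include <cstddef>
#include <memory>
#include <vector>

struct TensorSketch {
  size_t aligned_size_ = 0;
};
constexpr size_t kGapSizeSketch = 512;  // stand-in value, not the real kGapSize

// Pad only the first and last tensors of a contiguous group, and only if
// SOMAS plans them (aligned_size_ != 0).
void PadEnds(std::vector<std::shared_ptr<TensorSketch>> *tensors) {
  auto add_gap_if_planned = [](const std::shared_ptr<TensorSketch> &t) {
    if (t->aligned_size_ != 0) {
      t->aligned_size_ += kGapSizeSketch;
    }
  };
  if (!tensors->empty()) {
    add_gap_if_planned(tensors->front());
    add_gap_if_planned(tensors->back());
  }
}
```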
@@ -1097,17 +1116,33 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   // Ref Node Preprocessing
   UpdateRefTensorsConflict();
   std::map<size_t, size_t> contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor();
-  vector<vector<size_t>> contiguous_tensors_list_removed_ref = contiguous_tensors_list_;
+  vector<vector<size_t>> contiguous_tensors_list_removed = contiguous_tensors_list_;
   std::set<vector<size_t>> contiguous_tensors_list_to_remove;
   for (auto ref_list_pair : contiguous_list_with_ref_index_map) {
     contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]);
   }
+  // Remove any contiguous list whose tensors all have aligned size 0.
+  for (auto contiguous_list : contiguous_tensors_list_) {
+    bool all_outputs = true;
+    for (auto tensor_id : contiguous_list) {
+      auto tensor = tensors_list_[tensor_id];
+      if (tensor->aligned_size_ != 0) {
+        all_outputs = false;
+        break;
+      }
+    }
+    if (all_outputs) {
+      contiguous_tensors_list_to_remove.insert(contiguous_list);
+    }
+  }
   for (auto contiguous_list : contiguous_tensors_list_to_remove) {
-    auto iterator = std::find(contiguous_tensors_list_removed_ref.begin(), contiguous_tensors_list_removed_ref.end(),
-                              contiguous_list);
-    if (iterator != contiguous_tensors_list_removed_ref.end()) {
-      contiguous_tensors_list_removed_ref.erase(iterator);
+    auto iterator =
+      std::find(contiguous_tensors_list_removed.begin(), contiguous_tensors_list_removed.end(), contiguous_list);
+    if (iterator != contiguous_tensors_list_removed.end()) {
+      contiguous_tensors_list_removed.erase(iterator);
     } else {
       MS_LOG(WARNING) << "Could not find contiguous list to remove for ref";
     }
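Since this PR already adds `#include <algorithm>` elsewhere, the new all-zero scan could be written more compactly with `std::all_of`. A stand-in sketch (hypothetical minimal tensor type, not the real SomasTensor):

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <set>
#include <vector>

struct TensorSketch {
  size_t aligned_size_ = 0;
};

// Collect every contiguous list whose member tensors all have aligned size 0,
// equivalent to the hand-rolled flag-and-break loop above.
void CollectAllZeroLists(const std::vector<std::vector<size_t>> &contiguous_lists,
                         const std::vector<std::shared_ptr<TensorSketch>> &tensors,
                         std::set<std::vector<size_t>> *to_remove) {
  for (const auto &list : contiguous_lists) {
    bool all_zero = std::all_of(list.begin(), list.end(),
                                [&tensors](size_t id) { return tensors[id]->aligned_size_ == 0; });
    if (all_zero) {
      (void)to_remove->insert(list);
    }
  }
}
```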
@@ -1142,7 +1177,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   somas_solver_ = std::make_shared<SomasSolverPre>();
   auto status =
-    somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed_ref, false);
+    somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed, false);
   MS_LOG(INFO) << "End Solving";
   if (status != SUCCESS) {
     GenGraphStatisticInfo();
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
+#include <algorithm>
 #include "runtime/device/ascend/ascend_memory_pool.h"
 #include "runtime/device/ascend/ascend_kernel_runtime.h"
 #include "utils/log_adapter.h"
@@ -21,6 +22,9 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
+// The minimum memory block unit size (256 MB) used for dynamic extension.
+static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE = 256 << 20;
 void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size, uint64_t dynamic_mem_offset) {
   static bool initialized = false;
   if (initialized) {
@@ -40,11 +44,43 @@ void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size,
   initialized = true;
 }
+size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size) {
+  auto device_free_mem_size = free_mem_size();
+  if (device_free_mem_size < size) {
+    MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << device_free_mem_size
+                      << "] is smaller than required size[" << size << "], dynamic offset["
+                      << graph_dynamic_mem_offset_ << "] memory pool offset["
+                      << device_mem_size_ - device_mem_pool_offset_ << "]";
+    return 0;
+  }
+  auto alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
+  if (pynative_mode) {
+    // Grow the alloc size by doubling.
+    while (alloc_mem_size < size) {
+      alloc_mem_size = alloc_mem_size * 2;
+    }
+  } else {
+    while (alloc_mem_size < size) {
+      alloc_mem_size = alloc_mem_size + ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+    }
+  }
+  alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
+  return alloc_mem_size;
+}
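To make the two growth policies concrete, here is a tiny standalone program (not MindSpore code) that reproduces the arithmetic for a hypothetical 600 MiB request against the 256 MiB unit: PyNative mode doubles up to 1024 MiB, graph mode steps linearly to 768 MiB. Both results are then clamped to free device memory by the `std::min` at the end of the real function.

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const size_t unit = 256ULL << 20;     // 256 MiB minimum unit
  const size_t request = 600ULL << 20;  // hypothetical 600 MiB request
  size_t doubling = unit, linear = unit;
  while (doubling < request) doubling *= 2;  // PyNative-mode policy
  while (linear < request) linear += unit;   // graph-mode policy
  std::printf("doubling: %zu MiB, linear: %zu MiB\n",  // 1024 MiB, 768 MiB
              doubling >> 20, linear >> 20);
  return 0;
}
```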
 size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
+  MS_LOG(INFO) << "Malloc Memory: Pool, total[" << device_mem_size_ << "] (dynamic[" << graph_dynamic_mem_offset_
+               << "] memory pool[" << device_mem_size_ - device_mem_pool_offset_ << "])"
+               << " malloc [" << size << "]";
   if (size == 0) {
     MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
   }
-  if (device_mem_pool_offset_ - size <= graph_dynamic_mem_offset_) {
+  if (device_mem_pool_offset_ - size < graph_dynamic_mem_offset_) {
     MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ ["
                       << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_
                       << "], need memory size [" << size << "]";
@@ -76,8 +112,6 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
   return size;
 }
-size_t AscendMemoryPool::mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE / 4; }
 void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
   MS_EXCEPTION_IF_NULL(device_mem_pool_base);
   device_mem_pool_base_ = device_mem_pool_base;
@@ -50,8 +50,8 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
  protected:
   // The real size by memory alloc aligned.
   size_t AlignMemorySize(size_t size) const override;
-  // Get the minimum memory unit size using for dynamic extend.
-  size_t mem_alloc_unit_size() const override;
+  // Calculate the required alloc size when adding a memory block.
+  size_t CalMemBlockAllocSize(size_t size) override;
  private:
   AscendMemoryPool() = default;
@@ -340,7 +340,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 #endif
     auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
     device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
-    MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
+    MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
+                 << " index: " << index << " size: " << tensor_size;
     if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
       MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
     }
@@ -65,10 +65,12 @@ void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph *graph) {
   size_t total_allocated_size = somas_reuse_util_ptr->GetTotalMemSize();
   MS_LOG(INFO) << "Graph " << graph->graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
-  auto base_ptr = MallocDynamicMem(total_allocated_size, false);
-  MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
-               << static_cast<void *>(base_ptr + total_allocated_size) << "]";
-  somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
+  if (total_allocated_size > 0) {
+    auto base_ptr = MallocDynamicMem(total_allocated_size, false);
+    MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
+                 << static_cast<void *>(base_ptr + total_allocated_size) << "]";
+    somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
+  }
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -166,66 +168,7 @@ uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddress
   return ptr;
 }
-uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
-  size_t align_size = 0;
-  if (communication_mem) {
-    align_size = GetCommunicationAlignSize(size);
-  } else {
-    align_size = GetCommonAlignSize(size);
-  }
-  MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
-               << "] static[" << total_static_size_ << "])"
-               << " malloc [" << align_size << "] communication_mem: " << communication_mem;
-  if (static_mem_offset_ < align_size) {
-    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
-                      << "] static[" << total_static_size_ << "])"
-                      << " malloc [" << align_size << "] failed!";
-  }
-  total_static_size_ += align_size;
-  auto offset = static_mem_offset_ - align_size;
-  if (dynamic_mem_offset_ > offset) {
-    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
-                      << "] static[" << total_static_size_ << "])"
-                      << " malloc [" << align_size << "] failed!";
-  }
-  static_mem_offset_ = offset;
-  if (communication_mem) {
-    return device_mem_base_ + offset + kMemAlignSize;
-  } else {
-    return device_mem_base_ + offset;
-  }
-}
-uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
-  size_t align_size = 0;
-  if (communication_mem) {
-    align_size = GetCommunicationAlignSize(size);
-  } else {
-    align_size = GetCommonAlignSize(size);
-  }
-  MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
-               << "] static[" << total_static_size_ << "])"
-               << " malloc [" << align_size << "] communication_mem: " << communication_mem;
-  uint64_t offset = dynamic_mem_offset_;
-  auto new_offset = dynamic_mem_offset_ + align_size;
-  if (new_offset > static_mem_offset_) {
-    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
-                      << "] static[" << total_static_size_ << "])"
-                      << " malloc [" << align_size << "] failed!";
-  }
-  total_dynamic_size_ += align_size;
-  dynamic_mem_offset_ = new_offset;
-  if (communication_mem) {
-    return device_mem_base_ + offset + kMemAlignSize;
-  } else {
-    return device_mem_base_ + offset;
-  }
-}
+uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { return nullptr; }
 bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
   auto device_ptr = MallocMemFromMemPool(size);
@@ -64,7 +64,7 @@ class MemoryManager {
   size_t GetCommunicationAlignSize(size_t input_size) const;
  protected:
-  virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId);
+  virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
   virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
   uint8_t *device_mem_base_{nullptr};
   uint64_t device_mem_size_{0};
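Together with the previous hunk, this shifts responsibility to the backends: the base-class static allocator is gone and `MallocStaticMem` is now pure virtual, so a backend that forgets to implement static allocation fails to compile instead of inheriting a default that no longer applies, while the dynamic variant degenerates to a stub. A self-contained sketch of the resulting shape (stand-in names):

```cpp
#include <cstddef>
#include <cstdint>

class MemoryManagerSketch {
 public:
  virtual ~MemoryManagerSketch() = default;

 protected:
  // Pure virtual: each backend must implement static allocation itself.
  virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem) = 0;
  // Default dynamic allocation is a stub that backends may override.
  virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem) { return nullptr; }
};
```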
@@ -20,7 +20,7 @@ from mindspore import Tensor
 from mindspore.common.api import ms_function
 from mindspore.ops import operations as P
-context.set_context(device_target="Ascend")
+context.set_context(device_target="Ascend", mode=context.GRAPH_MODE, variable_memory_max_size="31GB")
 class Net(nn.Cell):
@@ -34,8 +34,12 @@ class Net(nn.Cell):
 def test_net():
-    x = np.random.randn(2, 3, 3, 4).astype(np.float32)
+    # element count = (31GB / 2 - 512) / sizeof(float32) = 4160749440 = 16 * 120 * 2167057
+    x = np.random.randn(16, 120, 2167057).astype(np.float32)
     relu = Net()
     output = relu(Tensor(x))
+    expect = 1 * (x > 0) * x
     print(x)
     print(output.asnumpy())
+    print(expect)
+    assert (output.asnumpy() == expect).all()