|
|
|
@@ -36,8 +36,8 @@ void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc |
|
|
|
} |
|
|
|
// Allocate memory through the device context. |
|
|
|
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) { |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*op_context), device_context, |
|
|
|
from_aid.Name(), device_tensor->GetSize()); |
|
|
|
SetOpContextMemoryAllocFail(from_aid.Name(), device_context, device_tensor->GetSize(), op_context); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -69,8 +69,8 @@ void MemoryManagerActor::AllocateContinuousMemory(const std::vector<std::vector< |
|
|
|
auto &device_context = (*device_contexts)[i]; |
|
|
|
// Allocate memory through the device context. |
|
|
|
if (!device_context->AllocateContinuousMemory(alloc_list, total_size, size_list)) { |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*op_context), device_context, |
|
|
|
from_aid.Name(), total_size); |
|
|
|
SetOpContextMemoryAllocFail(from_aid.Name(), device_context, total_size, op_context); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -100,8 +100,8 @@ void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> * |
|
|
|
|
|
|
|
// Allocate memory through the device context. |
|
|
|
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) { |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*op_context), device_context, |
|
|
|
from_aid.Name(), device_tensor->GetSize()); |
|
|
|
SetOpContextMemoryAllocFail(from_aid.Name(), device_context, device_tensor->GetSize(), op_context); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -165,5 +165,22 @@ void MemoryManagerActor::Wait(OpContext<DeviceTensor> *const op_context, const A |
|
|
|
// Call back to the from actor to process. |
|
|
|
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context); |
|
|
|
} |
|
|
|
|
|
|
|
void MemoryManagerActor::SetOpContextMemoryAllocFail(const std::string &kernel_name, |
|
|
|
const DeviceContext *device_context, size_t alloc_size, |
|
|
|
OpContext<DeviceTensor> *const op_context) { |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
MS_EXCEPTION_IF_NULL(op_context); |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_context->sequential_num_); |
|
|
|
auto step_id = uuids::uuid::ToBytes(*(op_context->sequential_num_)); |
|
|
|
// First occur allocating memory failed. |
|
|
|
if (mem_alloc_failed_step_ids_.find(step_id) == mem_alloc_failed_step_ids_.end()) { |
|
|
|
mem_alloc_failed_step_ids_.clear(); |
|
|
|
(void)mem_alloc_failed_step_ids_.insert(step_id); |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*op_context), device_context, |
|
|
|
kernel_name, alloc_size); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace runtime |
|
|
|
} // namespace mindspore |