|
|
|
@@ -1456,23 +1456,23 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); |
|
|
|
const auto tensor_size = LongToSize(tensor->data().nbytes()); |
|
|
|
if (tensor_address == device_address) { |
|
|
|
if (tensor->NeedSyncHostToDevice()) { |
|
|
|
tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(), |
|
|
|
tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_); |
|
|
|
tensor->set_sync_status(kNoNeedSync); |
|
|
|
} |
|
|
|
if (mem_scheduler->HasDeviceMem(tensor_address.get())) { |
|
|
|
tensor_address->set_ptr(nullptr); |
|
|
|
tensor->set_device_address(nullptr); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool need_sync = false; |
|
|
|
if (tensor->NeedSyncHostToDevice()) { |
|
|
|
mem_scheduler->AddMemNeedInit(device_address.get()); |
|
|
|
} else if (tensor_address != nullptr) { |
|
|
|
need_sync = true; |
|
|
|
} else if (tensor_address != device_address) { |
|
|
|
tensor->data_sync(false); |
|
|
|
mem_scheduler->AddMemNeedInit(device_address.get()); |
|
|
|
need_sync = true; |
|
|
|
} |
|
|
|
if (mem_scheduler->HasDeviceMem(device_address.get())) { |
|
|
|
device_address->set_ptr(nullptr); |
|
|
|
} |
|
|
|
if (need_sync) { |
|
|
|
if (device_address->GetPtr() != nullptr) { |
|
|
|
device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(), |
|
|
|
tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_); |
|
|
|
} else { |
|
|
|
mem_scheduler->AddMemNeedInit(device_address.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
MemPriority priority = kMemPriorityLow; |
|
|
|
const auto ¶meter = input_node->cast<ParameterPtr>(); |
|
|
|
@@ -1642,18 +1642,17 @@ void KernelRuntime::SyncParameter(const session::KernelGraph &graph, |
|
|
|
if (!AnfAlgo::IsParameterWeight(parameter) && !graph.IsUpdatedParameter(parameter)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto tensor = input_tensors[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
if (mem_scheduler->HasDeviceMem(device_address.get())) { |
|
|
|
auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh); |
|
|
|
device_address->set_ptr(device_ptr); |
|
|
|
auto tensor = input_tensors[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto origin_tensor_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); |
|
|
|
if (origin_tensor_device_address != nullptr) { |
|
|
|
origin_tensor_device_address->set_ptr(nullptr); |
|
|
|
} |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
tensor->set_sync_status(kNeedSyncDeviceToHost); |
|
|
|
} |
|
|
|
if (graph.IsUpdatedParameter(parameter)) { |
|
|
|
tensor->SetIsUpdateByDevice(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|