|
|
|
@@ -1352,39 +1352,54 @@ void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &m |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod); |
|
|
|
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) { |
|
|
|
auto tensor = graph.GetNodeOutputTensor(std::make_pair(kernel, j)); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true); |
|
|
|
if (mock) { |
|
|
|
if (graph.IsInternalOutput(kernel, j) && device_address != nullptr) { |
|
|
|
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh); |
|
|
|
} |
|
|
|
for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) { |
|
|
|
const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true); |
|
|
|
if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (tensor == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (device_address == nullptr) { |
|
|
|
tensor->data_sync(false); |
|
|
|
tensor->set_device_address(nullptr); |
|
|
|
tensor->set_sync_status(kNeedSyncHostToDevice); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!SyncStream()) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncStream failed"; |
|
|
|
} |
|
|
|
auto origin_ptr = device_address->ptr_; |
|
|
|
if (origin_ptr == nullptr) { |
|
|
|
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_); |
|
|
|
SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock); |
|
|
|
} |
|
|
|
for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) { |
|
|
|
SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, |
|
|
|
const KernelWithIndex &node_output_index, const session::KernelGraph &graph, |
|
|
|
bool mock) { |
|
|
|
MS_EXCEPTION_IF_NULL(mem_scheduler); |
|
|
|
if (node_output_index.first == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true); |
|
|
|
if (mock) { |
|
|
|
if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) { |
|
|
|
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh); |
|
|
|
} |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto tensor = graph.GetNodeOutputTensor(node_output_index); |
|
|
|
if (tensor == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
if (device_address == nullptr) { |
|
|
|
tensor->data_sync(false); |
|
|
|
tensor->set_device_address(nullptr); |
|
|
|
if (origin_ptr == nullptr) { |
|
|
|
device_address->ptr_ = nullptr; |
|
|
|
} |
|
|
|
tensor->set_sync_status(kNeedSyncHostToDevice); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!SyncStream()) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncStream failed"; |
|
|
|
} |
|
|
|
auto origin_ptr = device_address->ptr_; |
|
|
|
if (device_address->ptr_ == nullptr) { |
|
|
|
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_); |
|
|
|
} |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
tensor->data_sync(false); |
|
|
|
tensor->set_device_address(nullptr); |
|
|
|
device_address->ptr_ = origin_ptr; |
|
|
|
tensor->set_sync_status(kNeedSyncHostToDevice); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, |
|
|
|
|