|
|
|
@@ -103,17 +103,19 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { |
|
|
|
auto pk_node = input_node->cast<ParameterPtr>(); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); |
|
|
|
auto tensor_address = tensor->device_address(); |
|
|
|
bool need_sync = false; |
|
|
|
if (ms_context->enable_pynative_infer()) { |
|
|
|
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { |
|
|
|
if (tensor_address.get() == nullptr || tensor_address != device_address) { |
|
|
|
need_sync = true; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (tensor->is_dirty()) { |
|
|
|
} else if (tensor->is_dirty()) { |
|
|
|
need_sync = true; |
|
|
|
} else if (tensor_address != device_address) { |
|
|
|
if (tensor_address->DeviceType() == device_address->DeviceType()) { |
|
|
|
AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); |
|
|
|
} else { |
|
|
|
need_sync = true; |
|
|
|
} else if (tensor->device_address() != device_address) { |
|
|
|
AnfAlgo::SetOutputAddr(tensor->device_address(), 0, pk_node.get()); |
|
|
|
need_sync = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (need_sync) { |
|
|
|
|