|
|
|
@@ -676,7 +676,9 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
} |
|
|
|
} |
|
|
|
if (need_sync) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
if (AnfAlgo::IsParameterWeight(pk_node)) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
|