|
|
|
@@ -999,16 +999,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); |
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || |
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || |
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
} |
|
|
|
tensor->set_sync_status(kNoNeedSync); |
|
|
|
} |
|
|
|
|