|
|
|
@@ -1124,13 +1124,21 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, &tensor_to_node)); |
|
|
|
} |
|
|
|
|
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
for (auto &item : tensor_to_node) { |
|
|
|
auto &tensor = item.first; |
|
|
|
auto &node = item.second.first; |
|
|
|
auto &output_index = item.second.second; |
|
|
|
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
tensor->set_device_address(address); |
|
|
|
tensor->SetNeedWait(false); |
|
|
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { |
|
|
|
tensor->data_sync(false); |
|
|
|
tensor->set_sync_status(kNeedSyncHostToDevice); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|