|
|
|
@@ -89,6 +89,7 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De |
|
|
|
|
|
|
|
// Save the output nodes to clear the device tensor in the running end. |
|
|
|
output_nodes_[output_position] = node_with_index; |
|
|
|
output_device_tensors_[output_position] = input_data->data_; |
|
|
|
} |
|
|
|
|
|
|
|
TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) { |
|
|
|
@@ -152,7 +153,10 @@ void OutputActor::UpdateOutputDeviceAddress() { |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_device_address); |
|
|
|
auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false); |
|
|
|
if (i >= output_device_tensors_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid index:" << i << " current:" << output_device_tensors_.size(); |
|
|
|
} |
|
|
|
auto device_tensor = output_device_tensors_[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(device_tensor); |
|
|
|
|
|
|
|
// Update tensor device address by device tensor of output node. |
|
|
|
@@ -192,6 +196,8 @@ void OutputActor::UpdateOutputDeviceAddress() { |
|
|
|
output_node_to_tensor_device_address_.clear(); |
|
|
|
output_nodes_.clear(); |
|
|
|
output_nodes_.resize(outputs_num_); |
|
|
|
output_device_tensors_.clear(); |
|
|
|
output_device_tensors_.resize(outputs_num_); |
|
|
|
} |
|
|
|
} // namespace runtime |
|
|
|
} // namespace mindspore |