|
|
|
@@ -231,15 +231,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr & |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope() |
|
|
|
<< ", device type:" << another_device_type; |
|
|
|
if (host_tensor_address->DeviceType() == device::DeviceAddressType::kCPU) { |
|
|
|
// CPU device tensor copy to other device tensor. |
|
|
|
(void)another_device_tensor->SyncHostToDevice(host_tensor_address->GetSize(), host_tensor_address->GetPtr()); |
|
|
|
} else if (another_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { |
|
|
|
// Other device tensor copy to CPU device tensor. |
|
|
|
(void)host_tensor_address->SyncDeviceToHost(another_device_tensor->GetSize(), |
|
|
|
another_device_tensor->GetMutablePtr()); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid device type for sync data."; |
|
|
|
if (!Copy(another_device_tensor.get(), host_tensor_address.get())) { |
|
|
|
MS_LOG(EXCEPTION) << "Sync data error."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -312,10 +305,11 @@ void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size |
|
|
|
} |
|
|
|
|
|
|
|
(*host_tensors)[iter->second] = tensor; |
|
|
|
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()); |
|
|
|
if (device_address != nullptr) { |
|
|
|
AnfAlgo::SetOutputAddr(device_address, 0, node.get()); |
|
|
|
return; |
|
|
|
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) { |
|
|
|
AnfAlgo::SetOutputAddr(tensor_address, 0, node.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1565,8 +1559,14 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor |
|
|
|
|
|
|
|
// Set the member of the copy actor. |
|
|
|
MS_EXCEPTION_IF_NULL(from_device_tensor); |
|
|
|
auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first); |
|
|
|
MS_EXCEPTION_IF_NULL(to_kernel_mod); |
|
|
|
auto input_sizes = to_kernel_mod->GetInputSizeList(); |
|
|
|
if (to_input_index >= input_sizes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size(); |
|
|
|
} |
|
|
|
copy_actor->output_ = to_devcie_context->CreateDeviceAddress( |
|
|
|
nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id()); |
|
|
|
nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id()); |
|
|
|
MS_EXCEPTION_IF_NULL(from_devcie_context); |
|
|
|
copy_actor->input_device_context_ = from_devcie_context; |
|
|
|
copy_actor->output_device_context_ = to_devcie_context; |
|
|
|
|