|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -606,6 +606,37 @@ void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const D |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void DataPrepareActor::CopyDataFromHostToOtherDevice(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, |
|
|
|
const device::DeviceAddressPtr &host_tensor_address, |
|
|
|
const DeviceContext *device_context, |
|
|
|
OpContext<DeviceTensor> *context) const { |
|
|
|
MS_EXCEPTION_IF_NULL(backend_node); |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
MS_EXCEPTION_IF_NULL(context); |
|
|
|
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); |
|
|
|
if (device_tensors.size() > 1) { |
|
|
|
auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0]; |
|
|
|
MS_EXCEPTION_IF_NULL(another_device_tensor); |
|
|
|
auto another_device_type = another_device_tensor->DeviceType(); |
|
|
|
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( |
|
|
|
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_}); |
|
|
|
MS_EXCEPTION_IF_NULL(another_device_context); |
|
|
|
if ((another_device_tensor->GetPtr() == nullptr) && |
|
|
|
(!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) { |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context, |
|
|
|
backend_node->fullname_with_scope(), |
|
|
|
another_device_tensor->GetSize()); |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope() |
|
|
|
<< ", device type:" << another_device_type; |
|
|
|
if (!Copy(another_device_tensor.get(), host_tensor_address.get())) { |
|
|
|
std::string error_info = "Sync data error."; |
|
|
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Prepare the device data for persistent device tensor of weight node from host tensor. |
|
|
|
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node, |
|
|
|
const TensorPtr &tensor, const DeviceContext *device_context, |
|
|
|
@@ -639,7 +670,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, |
|
|
|
UpdateRefCount(host_tensor_address.get(), true); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(host_tensor_address); |
|
|
|
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) { |
|
|
|
|
|
|
|
if (host_tensor_address->DeviceType() == device_tensor->DeviceType() && |
|
|
|
!(host_tensor_address->format() != device_tensor->format() && strategy_ == GraphExecutionStrategy::kStep)) { |
|
|
|
// In the scenario of training + inference , the device address of the weight node can not be changed when |
|
|
|
// multi-graphs sink mode is set. |
|
|
|
if (device_tensor->is_ptr_persisted() && (host_tensor_address != device_tensor)) { |
|
|
|
@@ -653,8 +686,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, |
|
|
|
host_tensor_address->SetNodeIndex(backend_node, 0); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() |
|
|
|
<< ", device tensor type:" << device_tensor->DeviceType(); |
|
|
|
MS_LOG(INFO) << "The device type or format is not equal, host tensor type:" << host_tensor_address->DeviceType() |
|
|
|
<< " format:" << host_tensor_address->format() |
|
|
|
<< ", device tensor type:" << device_tensor->DeviceType() << " format:" << device_tensor->format(); |
|
|
|
if (strategy_ == GraphExecutionStrategy::kStep) { |
|
|
|
tensor->data_sync(); |
|
|
|
host_tensor_address = device_tensor; |
|
|
|
@@ -677,28 +711,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, |
|
|
|
} |
|
|
|
|
|
|
|
// Allocate another device memory and copy data from host tensor to another device(if exist). |
|
|
|
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); |
|
|
|
if (device_tensors.size() > 1) { |
|
|
|
auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0]; |
|
|
|
MS_EXCEPTION_IF_NULL(another_device_tensor); |
|
|
|
auto another_device_type = another_device_tensor->DeviceType(); |
|
|
|
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( |
|
|
|
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_}); |
|
|
|
MS_EXCEPTION_IF_NULL(another_device_context); |
|
|
|
if ((another_device_tensor->GetPtr() == nullptr) && |
|
|
|
(!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) { |
|
|
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context, |
|
|
|
backend_node->fullname_with_scope(), |
|
|
|
another_device_tensor->GetSize()); |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope() |
|
|
|
<< ", device type:" << another_device_type; |
|
|
|
if (!Copy(another_device_tensor.get(), host_tensor_address.get())) { |
|
|
|
std::string error_info = "Sync data error."; |
|
|
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info); |
|
|
|
} |
|
|
|
} |
|
|
|
CopyDataFromHostToOtherDevice(front_node, backend_node, host_tensor_address, device_context, context); |
|
|
|
} |
|
|
|
|
|
|
|
void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser, |
|
|
|
|