|
|
|
@@ -82,22 +82,21 @@ void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) { |
|
|
|
std::string error_info = "Launch graph exception, graph id: " + std::to_string(graph_->graph_id()); |
|
|
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); |
|
|
|
} |
|
|
|
|
|
|
|
for (auto item : ref_node_addr_map_) { |
|
|
|
const auto &input_node = item.first; |
|
|
|
auto formal_param_addr = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); |
|
|
|
MS_EXCEPTION_IF_NULL(item.first); |
|
|
|
MS_EXCEPTION_IF_NULL(item.second); |
|
|
|
auto formal_param_addr = AnfAlgo::GetMutableOutputAddr(item.first, 0, false); |
|
|
|
MS_EXCEPTION_IF_NULL(formal_param_addr); |
|
|
|
auto device_address = item.second; |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
MS_LOG(INFO) << "The input ref_node: " << input_node->DebugString() |
|
|
|
MS_LOG(INFO) << "The input ref_node: " << item.first->DebugString() |
|
|
|
<< " need copy back, from address: " << formal_param_addr->GetPtr() |
|
|
|
<< " to address: " << device_address->GetPtr() << "."; |
|
|
|
if (!device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(input_node, 0), formal_param_addr->GetSize(), |
|
|
|
formal_param_addr->type_id(), formal_param_addr->GetPtr(), |
|
|
|
formal_param_addr->format())) { |
|
|
|
MS_LOG(EXCEPTION) << "Sync device to device failed."; |
|
|
|
<< " to address: " << item.second->GetPtr() << "."; |
|
|
|
if (!Copy(item.second, formal_param_addr.get())) { |
|
|
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed."); |
|
|
|
} |
|
|
|
} |
|
|
|
ref_node_addr_map_.clear(); |
|
|
|
|
|
|
|
PostRun(context); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -110,6 +109,7 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) { |
|
|
|
} |
|
|
|
|
|
|
|
auto &input_nodes = graph_->input_nodes(); |
|
|
|
// Copy input data. |
|
|
|
for (auto &input_data : data_iter->second) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_data); |
|
|
|
if (IntToSize(input_data->index_) >= input_nodes.size()) { |
|
|
|
@@ -120,25 +120,18 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
|
|
|
|
auto &input_device_tensor = input_data->data_; |
|
|
|
MS_EXCEPTION_IF_NULL(input_device_tensor); |
|
|
|
if (input_device_tensor->DeviceType() != device_address->DeviceType()) { |
|
|
|
MS_LOG(ERROR) << "The input data device type:" << input_device_tensor->DeviceType() |
|
|
|
<< " is not equal to the graph node device type:" << device_address->DeviceType() << "."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (input_device_tensor->GetPtr() == device_address->GetPtr()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "The input data of node:" << input_node->DebugString() |
|
|
|
<< " need copy from address:" << input_device_tensor->GetPtr() |
|
|
|
<< " to address:" << device_address->GetPtr() << "."; |
|
|
|
if (!device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(input_node, 0), |
|
|
|
input_device_tensor->GetSize(), input_device_tensor->type_id(), |
|
|
|
input_device_tensor->GetPtr(), input_device_tensor->format())) { |
|
|
|
MS_LOG(ERROR) << "Sync device to device failed."; |
|
|
|
<< ", type:" << input_device_tensor->DeviceType() << " to address:" << device_address->GetPtr() |
|
|
|
<< ", type:" << device_address->DeviceType() << "."; |
|
|
|
if (!Copy(device_address.get(), input_device_tensor)) { |
|
|
|
MS_LOG(ERROR) << "Copy data failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (HasAbstractRef(input_node) && ref_node_addr_map_.count(input_node) == 0) { |
|
|
|
@@ -146,6 +139,29 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Check device tensor store. |
|
|
|
for (auto &device_tensor_store_key : device_tensor_store_keys_) { |
|
|
|
auto input_device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(), |
|
|
|
device_contexts_[0]->GetDeviceAddressType()); |
|
|
|
MS_EXCEPTION_IF_NULL(input_device_tensor); |
|
|
|
if (device_tensor_store_key.first >= input_nodes.size()) { |
|
|
|
MS_LOG(ERROR) << "The input index:" << device_tensor_store_key.first << "is out of range:" << input_nodes.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto input_node = input_nodes[device_tensor_store_key.first]; |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (input_device_tensor->GetPtr() != device_address->GetPtr()) { |
|
|
|
MS_LOG(ERROR) << "The input data of node:" << input_node->DebugString() |
|
|
|
<< " device address:" << input_device_tensor->GetPtr() |
|
|
|
<< ", type:" << input_device_tensor->DeviceType() |
|
|
|
<< " is not equal to the graph node device address:" << device_address->GetPtr() |
|
|
|
<< ", type:" << device_address->DeviceType() << "."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace runtime |
|
|
|
|