Browse Source

[Bugfix] Use SyncHostToDevice when the tensor address format is different from the parameter address format

feature/build-system-rewrite
caifubi 4 years ago
parent
commit
d7cb3cc5e7
2 changed files with 45 additions and 27 deletions
  1. +39
    -26
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
  2. +6
    -1
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h

+ 39
- 26
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -606,6 +606,37 @@ void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const D
}
}

// Copy the weight data held by host_tensor_address to the other device tensor registered in the
// DeviceTensorStore for the same front node (heterogeneous case: the weight exists on more than one
// device). Allocates the destination memory lazily if it has not been allocated yet.
// Extracted from PrepareDataForWeightNode to reduce its cyclomatic complexity.
//
// @param front_node          Front-end node used as the key into the DeviceTensorStore.
// @param backend_node        Backend node, used only for logging/error reporting.
// @param host_tensor_address Device address currently holding the prepared weight data.
// @param device_context      Context of the current device; supplies the device id for the lookup.
// @param context             Op context used to record allocation/copy failures.
void DataPrepareActor::CopyDataFromHostToOtherDevice(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
                                                     const device::DeviceAddressPtr &host_tensor_address,
                                                     const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *context) const {
  // Guard every pointer parameter that is dereferenced below; front_node and host_tensor_address
  // were previously unchecked although front_node.get() and host_tensor_address.get() are used.
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(host_tensor_address);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  // Only the two-device scenario needs a copy: pick whichever stored tensor is not the host one.
  if (device_tensors.size() > 1) {
    auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_type = another_device_tensor->DeviceType();
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    // Allocate destination memory on demand; on failure the macro records the error into *context.
    if ((another_device_tensor->GetPtr() == nullptr) &&
        (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context,
                                                  backend_node->fullname_with_scope(),
                                                  another_device_tensor->GetSize());
    }

    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << another_device_type;
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      std::string error_info = "Sync data error.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
    }
  }
}

// Prepare the device data for persistent device tensor of weight node from host tensor.
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
const TensorPtr &tensor, const DeviceContext *device_context,
@@ -639,7 +670,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
UpdateRefCount(host_tensor_address.get(), true);
}
MS_EXCEPTION_IF_NULL(host_tensor_address);
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {

if (host_tensor_address->DeviceType() == device_tensor->DeviceType() &&
!(host_tensor_address->format() != device_tensor->format() && strategy_ == GraphExecutionStrategy::kStep)) {
// In the scenario of training + inference , the device address of the weight node can not be changed when
// multi-graphs sink mode is set.
if (device_tensor->is_ptr_persisted() && (host_tensor_address != device_tensor)) {
@@ -653,8 +686,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
host_tensor_address->SetNodeIndex(backend_node, 0);
}
} else {
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType();
MS_LOG(INFO) << "The device type or format is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< " format:" << host_tensor_address->format()
<< ", device tensor type:" << device_tensor->DeviceType() << " format:" << device_tensor->format();
if (strategy_ == GraphExecutionStrategy::kStep) {
tensor->data_sync();
host_tensor_address = device_tensor;
@@ -677,28 +711,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
}

// Allocate another device memory and copy data from host tensor to another device(if exist).
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
if (device_tensors.size() > 1) {
auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
MS_EXCEPTION_IF_NULL(another_device_tensor);
auto another_device_type = another_device_tensor->DeviceType();
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(another_device_context);
if ((another_device_tensor->GetPtr() == nullptr) &&
(!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context,
backend_node->fullname_with_scope(),
another_device_tensor->GetSize());
}

MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
<< ", device type:" << another_device_type;
if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
std::string error_info = "Sync data error.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
}
}
CopyDataFromHostToOtherDevice(front_node, backend_node, host_tensor_address, device_context, context);
}

void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,


+ 6
- 1
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -100,6 +100,11 @@ class DataPrepareActor : public DebugAwareActor {
void PrepareDataForControlValueNode(const KernelWithIndex &node_with_index, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context);

// Extracted as a helper from PrepareDataForWeightNode in order to reduce its cyclomatic complexity.
void CopyDataFromHostToOtherDevice(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
const device::DeviceAddressPtr &host_tensor_address,
const DeviceContext *device_context, OpContext<DeviceTensor> *context) const;

const GraphCompilerInfo *graph_compiler_info_;
GraphExecutionStrategy strategy_;
GraphExecutionStrategy real_strategy_;


Loading…
Cancel
Save