Browse Source

fix bug of control flow together with host and device

tags/v1.4.0
limingqi107 4 years ago
parent
commit
7594537a01
5 changed files with 86 additions and 8 deletions
  1. +38
    -0
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
  2. +4
    -0
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.h
  3. +14
    -0
      mindspore/ccsrc/runtime/framework/control_node_parser.cc
  4. +3
    -0
      mindspore/ccsrc/runtime/framework/control_node_parser.h
  5. +27
    -8
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc

+ 38
- 0
mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc View File

@@ -34,6 +34,7 @@ void KernelActor::Init() {
is_dynamic_shape_ = AnfAlgo::IsDynamicShape(kernel_);

// Init the device tensors and kernel launch info.
copy_input_device_tensors_.resize(real_input_num_);
input_device_tensors_.resize(real_input_num_);
for (auto &input_address : input_device_tensors_) {
memory_free_list_.emplace_back(input_address);
@@ -277,6 +278,42 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
}
}

// Copy one input device tensor onto this actor's device when the sender lives on a
// different device type (control-flow / host-device mixed scenarios). On success the
// per-index copy buffer replaces the original input in input_device_tensors_ and
// memory_free_list_; returns early when no copy is needed.
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(input_data);
  auto src_tensor = input_data->data_;
  // Nothing to do when there is no payload or it already matches this device type.
  if (src_tensor == nullptr || src_tensor->DeviceType() == device_context_->GetDeviceAddressType()) {
    return;
  }

  MS_LOG(DEBUG) << "Copy from device type: " << src_tensor->DeviceType()
                << " to device type: " << device_context_->GetDeviceAddressType() << " in " << GetAID().Name();

  // Lazily create the per-index copy buffer, reusing it across launches.
  auto &dst_tensor = copy_input_device_tensors_[input_data->index_];
  if (dst_tensor == nullptr) {
    dst_tensor =
      device_context_->CreateDeviceAddress(nullptr, src_tensor->GetSize(), src_tensor->format(), src_tensor->type_id());
  }
  // Dynamic shape need update size.
  dst_tensor->SetSize(src_tensor->GetSize());

  // Allocate device memory for the copy buffer on first use (or after release).
  if ((dst_tensor->GetPtr() == nullptr) &&
      !device_context_->AllocateMemory(dst_tensor.get(), dst_tensor->GetSize())) {
    std::string error_info =
      "Device(id:" + std::to_string(device_context_->device_context_key().device_id_) +
      ") memory isn't enough and alloc failed, actor name: " + GetAID().Name() +
      ", alloc size: " + std::to_string(dst_tensor->GetSize());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  if (!Copy(dst_tensor.get(), src_tensor)) {
    std::string error_info = "Copy device tensor failed: " + GetAID().Name();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  // Update by the copy input device tensor.
  input_device_tensors_[input_data->index_] = dst_tensor.get();
  memory_free_list_[input_data->index_] = dst_tensor.get();
}

void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
@@ -289,6 +326,7 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
input_device_tensors_[input_data->index_] = input_data->data_;
memory_free_list_[input_data->index_] = input_data->data_;
}
CopyInputDeviceTensor(input_data, context);
}
}



+ 4
- 0
mindspore/ccsrc/runtime/framework/actor/kernel_actor.h View File

@@ -91,6 +91,7 @@ class KernelActor : public DebugAwareActor {
// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
void FetchOutputDeviceTensor();
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);

@@ -144,6 +145,9 @@ class KernelActor : public DebugAwareActor {
std::vector<DeviceTensor *> input_device_tensors_;
std::vector<DeviceTensor *> output_device_tensors_;
std::vector<DeviceTensor *> workspace_device_tensors_;
// The received input device type may be different from the device context type in the control flow and host device
// scenarios, so it needs to be copied from the input device type to the device context type.
std::vector<DeviceTensorPtr> copy_input_device_tensors_;

// The device tensors for memory alloc and free.
// output + workspace


+ 14
- 0
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -1287,6 +1287,13 @@ void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front
std::vector<AnfNodePtr> dest_nodes;
FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters);
host_parameter_to_weights_[pair.first] = dest_nodes;

if (std::find(root_graph_parameters_.begin(), root_graph_parameters_.end(), pair.first) !=
root_graph_parameters_.end()) {
for (auto &sub_front_node : dest_nodes) {
sub_front_node_to_root_front_node_[sub_front_node] = pair.first;
}
}
}
}

@@ -1584,5 +1591,12 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro
}
}
}

// Map a sub-graph front node back to its root-graph front node so that both share the
// same device tensor store entry. Falls back to the node itself when no mapping exists.
AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node) {
  // Single lookup instead of count() + operator[]: avoids traversing the map twice and
  // avoids the (guarded but fragile) mutating operator[] on a pure query path.
  const auto iter = sub_front_node_to_root_front_node_.find(sub_front_node);
  if (iter == sub_front_node_to_root_front_node_.end()) {
    return sub_front_node;
  }
  return iter->second;
}
} // namespace runtime
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/runtime/framework/control_node_parser.h View File

@@ -126,6 +126,8 @@ class ControlNodeParser {
return front_to_backend_kernels_[front_node_with_index].first;
}

AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);

private:
friend class GraphScheduler;

@@ -221,6 +223,7 @@ class ControlNodeParser {
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights_;
std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_;

// The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
// input of the control node.


+ 27
- 8
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -248,7 +248,14 @@ void PrepareDataForControlWeightNode(
MS_EXCEPTION_IF_NULL(device_context);

auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
if (device_tensors.empty()) {
bool need_update_device_tensor_store = (device_tensors.size() == 0) ? true : false;
for (auto &device_tensor : device_tensors) {
if (device_tensor->GetPtr() == nullptr) {
need_update_device_tensor_store = true;
break;
}
}
if (need_update_device_tensor_store) {
PrepareDataForWeightNode(node, front_node, tensor, device_context);
}

@@ -455,7 +462,7 @@ void GraphScheduler::Initialize() {
auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
<< ", the computed OMP thread number : " << OMP_thread_num
<< ", the used OMP thread number : " << stoi(OMP_thread_num_used);
<< ", the used OMP thread number : " << OMP_thread_num_used;

BuildAndScheduleGlobalActor();
}
@@ -2719,20 +2726,27 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
AnfNodePtr front_node = nullptr;
AnfNodePtr sub_front_node = nullptr;
if (IsInternalParameter(input_node, graph)) {
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(input_node);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
const auto &front_output_with_index =
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
front_node = front_output_with_index.first;
} else if (IsPersistentDeviceTensor(input_node)) {
front_node = FetchFrontNodeByBackendNode(input_node, graph);
sub_front_node = front_output_with_index.first;
} else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
}
if (front_node == nullptr) {
if (sub_front_node == nullptr) {
continue;
}

// The sub front nodes share the device tensor store with the root front node.
auto front_node = sub_front_node;
if (graph_compiler_info.control_node_parser_ != nullptr) {
front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
}
MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
<< ", root front node:" << front_node->DebugString();
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(input_node)) {
@@ -3091,7 +3105,12 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil
if (!IsPersistentDeviceTensor(input_node)) {
continue;
}
const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
// The sub front nodes share the device tensor store with the root front node.
auto front_node = sub_front_node;
if (graph_compiler_info.control_node_parser_ != nullptr) {
front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
}
const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
ofs << "\t\tdevcie tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
<< "\n";


Loading…
Cancel
Save