Browse Source

Skip front cnode for device tensor store.

r1.7
gaoyong10 4 years ago
parent
commit
4e0dbffbc1
2 changed files with 6 additions and 2 deletions
  1. +2
    -0
      mindspore/ccsrc/runtime/graph_scheduler/control_node_parser.cc
  2. +4
    -2
      mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc

+ 2
- 0
mindspore/ccsrc/runtime/graph_scheduler/control_node_parser.cc View File

@@ -1869,6 +1869,7 @@ void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &

kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info;
if (kernel_graph_group_info->need_stack_) {
MS_LOG(DEBUG) << "Add call input kernel graph:" << kernel_graph->ToString();
(void)call_input_kernel_graphs_.emplace(kernel_graph.get());
}
}
@@ -1876,6 +1877,7 @@ void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &
for (const auto &graph : kernel_graph_group_info->graphs_) {
kernel_graph_group_info->group_name_ += ("_" + std::to_string(graph->graph_id()));
}
MS_LOG(DEBUG) << "Add kernel graph info for group:" << kernel_graph_group_info->group_name_;
(void)kernel_graph_group_infos_.emplace(kernel_graph_group_info);
}
}


+ 4
- 2
mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc View File

@@ -1903,7 +1903,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ab
}

void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(device_tensor);
MS_LOG(DEBUG) << "Add device tensor store:" << device_tensor << " for node:" << anf_node->DebugString();
DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
UpdateRefCount(device_tensor.get(), true);
}
@@ -2050,8 +2052,8 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
front_node = FetchFrontNodeByBackendNode(input_node, graph);
}
// The front node may be value node in the heterogeneous scene, needs to handle.
if ((front_node == nullptr) ||
(front_node->isa<Parameter>() && !parser->IsRootGraphPersistentDeviceTensor(front_node))) {
if ((front_node == nullptr) || ((front_node->isa<Parameter>() || front_node->isa<CNode>()) &&
(!parser->IsRootGraphPersistentDeviceTensor(front_node)))) {
continue;
}



Loading…
Cancel
Save