|
|
|
@@ -1903,7 +1903,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ab |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
MS_EXCEPTION_IF_NULL(device_tensor); |
|
|
|
MS_LOG(DEBUG) << "Add device tensor store:" << device_tensor << " for node:" << anf_node->DebugString(); |
|
|
|
DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor); |
|
|
|
UpdateRefCount(device_tensor.get(), true); |
|
|
|
} |
|
|
|
@@ -2050,8 +2052,8 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler |
|
|
|
front_node = FetchFrontNodeByBackendNode(input_node, graph); |
|
|
|
} |
|
|
|
// The front node may be value node in the heterogeneous scene, needs to handle. |
|
|
|
if ((front_node == nullptr) || |
|
|
|
(front_node->isa<Parameter>() && !parser->IsRootGraphPersistentDeviceTensor(front_node))) { |
|
|
|
if ((front_node == nullptr) || ((front_node->isa<Parameter>() || front_node->isa<CNode>()) && |
|
|
|
(!parser->IsRootGraphPersistentDeviceTensor(front_node)))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
|