|
|
|
@@ -231,6 +231,25 @@ void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SetSummaryNodesRefCount(const KernelGraph *graph) { |
|
|
|
if (!graph->summary_node_exist()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes(); |
|
|
|
if (summary_nodes.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &item : summary_nodes) { |
|
|
|
const AnfNodePtr &node = item.second.first; |
|
|
|
size_t index = IntToSize(item.second.second); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
device_address->set_original_ref_count(SIZE_MAX); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void GraphCompiler::set_device_context(DeviceContext *device_context) { |
|
|
|
@@ -272,6 +291,9 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { |
|
|
|
MS_EXCEPTION_IF_NULL(session_); |
|
|
|
session_->InitAllBucket(graph, device_context_); |
|
|
|
|
|
|
|
session_->SetSummaryNodes(graph.get()); |
|
|
|
SetSummaryNodesRefCount(graph.get()); |
|
|
|
|
|
|
|
return graph->graph_id(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -412,5 +434,17 @@ void GraphCompiler::ClearAllBucket(const GraphId &graph_id) { |
|
|
|
MS_EXCEPTION_IF_NULL(session_); |
|
|
|
session_->ClearAllBucket(graph_id); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const { |
|
|
|
MS_EXCEPTION_IF_NULL(session_); |
|
|
|
session_->RegisterSummaryCallBackFunc(callback); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const { |
|
|
|
MS_EXCEPTION_IF_NULL(session_); |
|
|
|
for (const auto &graph : graphs) { |
|
|
|
session_->Summary(graph.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace runtime |
|
|
|
} // namespace mindspore |