Browse Source

unify runtime supports summary

tags/v1.3.0
lizhenyu 4 years ago
parent
commit
05473b2f2b
3 changed files with 45 additions and 0 deletions
  1. +34
    -0
      mindspore/ccsrc/runtime/framework/graph_compiler.cc
  2. +6
    -0
      mindspore/ccsrc/runtime/framework/graph_compiler.h
  3. +5
    -0
      mindspore/ccsrc/vm/backend.cc

+ 34
- 0
mindspore/ccsrc/runtime/framework/graph_compiler.cc View File

@@ -231,6 +231,25 @@ void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
}
}
}

void SetSummaryNodesRefCount(const KernelGraph *graph) {
if (!graph->summary_node_exist()) {
return;
}

const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
if (summary_nodes.empty()) {
return;
}

for (const auto &item : summary_nodes) {
const AnfNodePtr &node = item.second.first;
size_t index = IntToSize(item.second.second);
auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
MS_EXCEPTION_IF_NULL(device_address);
device_address->set_original_ref_count(SIZE_MAX);
}
}
} // namespace

void GraphCompiler::set_device_context(DeviceContext *device_context) {
@@ -272,6 +291,9 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context_);

session_->SetSummaryNodes(graph.get());
SetSummaryNodesRefCount(graph.get());

return graph->graph_id();
}

@@ -412,5 +434,17 @@ void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
MS_EXCEPTION_IF_NULL(session_);
session_->ClearAllBucket(graph_id);
}

void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
MS_EXCEPTION_IF_NULL(session_);
session_->RegisterSummaryCallBackFunc(callback);
}

void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
MS_EXCEPTION_IF_NULL(session_);
for (const auto &graph : graphs) {
session_->Summary(graph.get());
}
}
} // namespace runtime
} // namespace mindspore

+ 6
- 0
mindspore/ccsrc/runtime/framework/graph_compiler.h View File

@@ -29,6 +29,7 @@
namespace mindspore {
using device::DeviceContext;
using mindspore::tensor::TensorPtr;
using session::CallBackFunc;
using session::InputTensorInfo;
using session::KernelWithIndex;
using session::OpRunInfo;
@@ -94,6 +95,11 @@ class GraphCompiler {
// operator.
void ClearAllBucket(const GraphId &graph_id);

// Register a summary callback function, which is called in the final stages of summary.
void RegisterSummaryCallBackFunc(const CallBackFunc &callback) const;
// Execute graph summary.
void Summary(const std::vector<KernelGraphPtr> &graphs) const;

private:
GraphCompiler() = default;
~GraphCompiler() = default;


+ 5
- 0
mindspore/ccsrc/vm/backend.cc View File

@@ -246,6 +246,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtr root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
// Register a summary callback function, which is called in the final stages of summary.
runtime::GraphCompiler::GetInstance().RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);

// Compile root graph.
graph_id_to_device_context_.clear();
control_nodes_.clear();
@@ -468,6 +471,8 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(outputs.elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;

runtime::GraphCompiler::GetInstance().Summary(graph_compiler_info.graphs_);
return outputs;
}



Loading…
Cancel
Save