|
|
|
@@ -29,6 +29,7 @@
 #include "device/ascend/ascend_kernel_runtime.h"
 #include "device/ascend/ascend_device_address.h"
 #include "pre_activate/ascend/ascend_backend_optimization.h"
+#include "pre_activate/common/common_backend_optimization.h"
 #include "device/kernel_adjust.h"
 #include "device/ascend/ascend_stream_assign.h"
 #include "device/ascend/ascend_label_assign.h"
|
|
|
@@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
 GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   MS_LOG(INFO) << "start";
-  auto graph = ConstructKernelGraph(func_graph);
+  std::vector<KernelGraphPtr> all_graphs;
+  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
+  BackendOptimization(all_graphs);
   // split switch
-  SplitGraphs(NOT_NULL(graph));
+  SplitGraphs(NOT_NULL(root_graph));
   // insert goto labels and label_sets
-  LinkChildGraphs(NOT_NULL(graph));
+  LinkChildGraphs(NOT_NULL(root_graph));
   // resource initialize
   InitRuntimeResource();
   // assign label
-  AssignLabel(NOT_NULL(graph));
-  // recurse compile child graph
+  AssignLabel(NOT_NULL(root_graph));
+  // recursively compile child graphs
   std::set<KernelGraphPtr> memo;
-  RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
-  // root graph valiate,include genearte execute order and so on
-  RootGraphExecutorValidate(NOT_NULL(graph));
+  RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
+  // validate the root graph, including generating the execution order and so on
+  RootGraphExecutorValidate(NOT_NULL(root_graph));
   // adjust kernel
-  AdjustKernel(graph);
+  AdjustKernel(root_graph);
   // assign stream
-  AssignStream(graph);
+  AssignStream(root_graph);
   // insert profiling point
-  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
+  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
   // build kernel
-  BuildKernel(graph);
+  BuildKernel(root_graph);
   // alloc mem
-  MemoryAlloc(graph.get());
+  MemoryAlloc(root_graph.get());
   // task generate
-  GenerateTaskInfo(graph);
+  GenerateTaskInfo(root_graph);
   // load task into device
-  LoadTask(graph);
-  // return the graph id to backend
-  auto graph_id = graph->graph_id();
+  LoadTask(root_graph);
+  // return the root graph id to the backend
+  auto graph_id = root_graph->graph_id();
   return graph_id;
 }
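
For orientation, a minimal caller-side sketch of where this new whole-graph entry point sits. Apart from the CompileGraph(NotNull<FuncGraphPtr>) call itself, everything below (the include paths, Init(), RunGraph(), and the helper name) is an assumption based on the surrounding session API rather than something this diff introduces:

// Illustration only; these are not lines from this change.
#include <memory>
#include <vector>
#include "session/ascend_session.h"  // assumed include path for AscendSession

void CompileAndRunWholeGraph(const FuncGraphPtr &func_graph, uint32_t device_id,
                             const std::vector<tensor::TensorPtr> &inputs) {
  auto session = std::make_shared<session::AscendSession>();
  session->Init(device_id);  // assumed session setup call
  // Entry point added above: constructs all kernel graphs, splits control flow,
  // assigns labels/streams, builds kernels, loads tasks, and returns the root graph id.
  GraphId graph_id = session->CompileGraph(NOT_NULL(func_graph));
  VectorRef outputs;
  session->RunGraph(graph_id, inputs, &outputs);  // assumed run call
}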
|
|
|
|
|
|
|
@@ -1569,6 +1572,14 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
   return call_node_inputs;
 }
 
+void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
+  MS_LOG(INFO) << "Start BackendCommonOptimization";
+  for (auto &graph : all_graphs) {
+    opt::BackendCommonOptimization(graph);
+  }
+  MS_LOG(INFO) << "End.";
+}
+
 void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
   std::set<KernelGraphPtr> memo;
   // if root graph output is a call node ,the root graph is condition graph of 'if' sentence
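
Both RecurseCompileGraph above and SplitGraphs here thread a std::set<KernelGraphPtr> memo through the recursion so that a child graph shared by several parents is only handled once. A minimal generic sketch of that visit-once pattern, using placeholder Graph/GraphPtr types rather than the real KernelGraph:

#include <memory>
#include <set>
#include <vector>

// Memo-guarded recursion: each graph is processed exactly once, even when
// several parent graphs reference the same child graph.
struct Graph {
  std::vector<std::shared_ptr<Graph>> children;
};
using GraphPtr = std::shared_ptr<Graph>;

void RecurseProcess(const GraphPtr &graph, std::set<GraphPtr> *memo) {
  if (memo->count(graph) > 0) {
    return;  // already handled through another parent
  }
  (void)memo->insert(graph);
  // ... per-graph work (split, optimize, build kernels, ...) goes here ...
  for (const auto &child : graph->children) {
    RecurseProcess(child, memo);
  }
}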
|
|
|
|