|
|
|
@@ -56,59 +56,60 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
   return graph->graph_id();
 }
 
 void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                              VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(session_);
   auto graph = session_->GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(graph);
   auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
   MS_EXCEPTION_IF_NULL(actor_set);
   GraphScheduler::GetInstance().Run(actor_set);
 }
 
-void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                                       std::vector<tensor::TensorPtr> *input_tensors,
-                                       const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
-  // Check if the graph cache exists.
-  if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
-    // Prepare the graph
-    MS_EXCEPTION_IF_NULL(session_);
-    auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
-    MS_EXCEPTION_IF_NULL(graph);
-
-    MS_EXCEPTION_IF_NULL(device_context_);
-    device_context_->SetOperatorInfo(graph->execution_order());
-    device_context_->OptimizeSingleOpGraph(graph);
-
-    MS_EXCEPTION_IF_NULL(session_);
-    session_->RunOpHideNopNode(graph);
-
-    device_context_->CreateKernel(graph->execution_order());
-    run_op_graphs_[graph_info] = graph;
-  }
-
-  session_->EraseValueNodeTensor(tensors_mask, input_tensors);
-
-  // wait for allreduce
-  for (auto &tensor : *input_tensors) {
-    if (tensor->NeedWaitDevice()) {
-      tensor->WaitDevice();
-    }
-  }
-
-  // run op
-  auto graph = run_op_graphs_[graph_info];
-  MS_EXCEPTION_IF_NULL(graph);
-  session_->RunOpRemoveNopNode(graph);
-  GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
-  auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
-  MS_EXCEPTION_IF_NULL(actor_set);
-  GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
-}
+GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                                    std::vector<tensor::TensorPtr> *input_tensors,
+                                    const std::vector<int64_t> &tensors_mask) {
+  // Check if the graph cache exists.
+  auto iter = run_op_graphs_.find(graph_info);
+  if (iter != run_op_graphs_.end()) {
+    const auto &graph = iter->second;
+    MS_EXCEPTION_IF_NULL(graph);
+    return graph->graph_id();
+  }
+  // Generate kernel graph.
+  MS_EXCEPTION_IF_NULL(session_);
+  auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
+  MS_EXCEPTION_IF_NULL(graph);
+
+  MS_EXCEPTION_IF_NULL(device_context_);
+  device_context_->SetOperatorInfo(graph->execution_order());
+  device_context_->OptimizeSingleOpGraph(graph);
+
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->RunOpHideNopNode(graph);
+  session_->RunOpRemoveNopNode(graph);
+
+  // Generate 'KernelMod' for kernel in graph.
+  device_context_->CreateKernel(graph->execution_order());
+
+  // Transform graph to actor DAG, contains build and link.
+  GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
+  run_op_graphs_[graph_info] = graph;
+  return graph->graph_id();
+}
 
 KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
   MS_EXCEPTION_IF_NULL(session_);
   return session_->GetGraph(graph_id);
 }
+
+KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
+  auto iter = run_op_graphs_.find(graph_info);
+  if (iter == run_op_graphs_.end()) {
+    MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
+    return nullptr;
+  }
+  return iter->second;
+}
 }  // namespace runtime
 }  // namespace mindspore
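
Taken as a whole, the hunk splits the fused CompileAndRunGraph into a compile-only CompileGraph, which now also transforms the graph into its actor DAG and caches it under graph_info, plus a Fetch overload keyed by GraphInfo, leaving execution to the caller. Below is a minimal caller-side sketch of the resulting flow; RunSingleOp is a hypothetical helper, and op_run_info, graph_info, input_tensors, and tensors_mask are assumed to be prepared by the caller exactly as they were for the removed function.

// Hypothetical caller (not part of this patch): performs the
// compile -> fetch -> run sequence that CompileAndRunGraph used to fuse.
void RunSingleOp(GraphCompiler *compiler, session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
                 std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask) {
  // Compile once per graph_info; a cache hit returns the existing graph id.
  (void)compiler->CompileGraph(op_run_info, graph_info, input_tensors, tensors_mask);
  // Look up the cached kernel graph by its GraphInfo key.
  KernelGraphPtr graph = compiler->Fetch(graph_info);
  MS_EXCEPTION_IF_NULL(graph);
  // Run through the actor runtime, as the removed run-op section did.
  auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
  MS_EXCEPTION_IF_NULL(actor_set);
  GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
}

The payoff of the split is that the expensive steps (graph construction, operator selection, kernel creation, actor transform) run once per cached graph_info, while each later dispatch of the same single-op graph only pays for Fetch and Run.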