/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "vm/backend.h"

// NOTE(review): the two standard headers below had their names stripped by text
// extraction ("#include" with no argument); <algorithm> and <vector> are needed by
// the std::transform/std::back_inserter/std::vector uses in this file — confirm
// against upstream.
#include <algorithm>
#include <vector>

#include "vm/transform.h"
#include "backend/session/session_factory.h"
#include "pipeline/pynative/pynative_execute.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "runtime/framework/graph_scheduler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif

namespace mindspore {
namespace compile {
// Convert a BaseRef condition value into a C++ bool. Returns false on conversion failure.
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }

// Convert a BaseRef index value into an int64. Returns false on conversion failure.
// NOTE(review): the cast's template argument was stripped by extraction; ValuePtr
// restored to match BaseRefToInt's expected argument — verify against upstream.
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) {
  return BaseRefToInt(utils::cast<ValuePtr>(c), value);
}

Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "select backend:" << name;
  convert_fn_ = MsVmConvert;
  is_multi_graph_sink_ = false;
}

// Compile one graph segment for the given target device and wrap the compiled
// graph into a LinConvertResult (inputs/outputs plus run/simu_run closures).
// Results are cached per segment unless running inside a dynamic cell.
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  // Reuse a previously compiled result for this segment when available.
  auto cached = g_ConvertCache.find(segment);
  if (cached != g_ConvertCache.end()) {
    return cached->second;
  }

  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;

  // Segments whose target differs from this backend's device compile on a
  // secondary ("other") session.
  auto current_session = target_sess_;
  if (target != target_device_ && !target.empty()) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);
  GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);

  // Link this graph with the graphs of all predecessor segments; a predecessor
  // may live on either session.
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    if (pre_graph == nullptr) {
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }

  // NOTE(review): get_param template arguments (<bool>, <int>) were stripped by
  // extraction and restored from the MsContext API convention — verify upstream.
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }

  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  // In PyNative mode on Ascend, graph build is deferred; otherwise build now
  // (multi-graph-sink builds happen later via Link()).
  if (!pynative_mode || target != "Ascend") {
    if (target != target_device_ && !target.empty()) {
      other_sess_->BuildGraph(graph_id);
    } else if (!is_multi_graph_sink_) {
      target_sess_->BuildGraph(graph_id);
    }
  }

  // NOTE(review): make_shared template argument stripped by extraction; RunFunc
  // restored to match LinConvertResult::run's declared type — verify upstream.
  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);
  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
  MS_EXCEPTION_IF_NULL(result.simu_run);
  result.graph_id = graph_id;
  graph_id_map_[graph_id] = result;
  // Dynamic cells can recompile between steps, so caching would be stale.
  if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
    (void)g_ConvertCache.emplace(segment, result);
  }
  return result;
}

// compile set input output
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { MS_LOG(DEBUG) << "set graph input:" << g; std::vector outputs; (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs), [](const AnfNodePtr &v) { return v; }); return VectorRef(outputs); } namespace { void PushInputTensor(const BaseRef &arg, std::vector *inputs) { MS_EXCEPTION_IF_NULL(inputs); if (utils::isa(arg)) { auto value = utils::cast(arg); inputs->push_back(value); } else if (utils::isa(arg)) { auto value = utils::cast(arg); MS_EXCEPTION_IF_NULL(value); if (value->isa()) { auto value_tuple = value->cast(); MS_EXCEPTION_IF_NULL(value_tuple); auto tuple_value = value_tuple->value(); (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs), [](const ValuePtr &v) { return v->cast(); }); } else if (value->isa()) { tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast()); inputs->push_back(scalar_tensor); } else if (value->isa()) { // If value is a monad, replace it with an unused tensor. inputs->push_back(std::make_shared(int64_t(0), kBool)); } else { inputs->push_back(value->cast()); } } else if (utils::isa(arg)) { auto value = utils::cast(arg).object_; inputs->push_back(py::cast(value)); } else if (utils::isa(arg)) { const auto &args_new = utils::cast(arg); for (const auto &v : args_new) { PushInputTensor(v, inputs); } } else { MS_LOG(WARNING) << "Invalid input type."; } } } // namespace VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; // Run graph std::vector inputs; for (const auto &arg : args) { PushInputTensor(arg, &inputs); } VectorRef outputs; // Call ms RunGraphAsync or RunOpsInGraph (graphId, input ,output) const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? 
other_sess_ : target_sess_); auto ms_context = MsContext::GetInstance(); const bool pynative_mode = (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode); if (pynative_mode) { exe_session->RunOpsInGraph(g, inputs, &outputs); } else { exe_session->RunGraphAsync(g, inputs, &outputs); } MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); return outputs; } void MsBackend::Link(GraphId graph_id) { if (graph_id == kInvalidGraphId) { graph_id = target_sess_->GetFinalRunGraph(); } target_sess_->BuildGraph(graph_id); } MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); target_sess_ = session::SessionFactory::Get().Create(target); if (target_sess_ == nullptr) { MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; } target_sess_->Init(device_id); target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); target_device_ = target; } void MsBackend::CreateOtherSession(const std::string &target) { if (other_sess_ != nullptr && other_device_ == target) { return; } other_sess_ = session::SessionFactory::Get().Create(target); if (other_sess_ == nullptr) { MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; } auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); uint32_t device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); other_sess_->Init(device_id); other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); other_device_ = target; } GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_->CompileGraph(fg); } VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } void MsBackend::ClearSessionGraphs() { if (target_sess_ != nullptr) { target_sess_->ClearGraph(); } } #ifdef 
ENABLE_DEBUGGER void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } #endif MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id) : Backend(backend_name), device_name_(device_name), device_id_(device_id) { auto cut_list = compile::GetMsNonlinearOps(); graph_partition_ = std::make_shared(cut_list, backend_name); } GraphId MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); FuncGraphPtr root_graph = WrapPrimitives(func_graph); MS_EXCEPTION_IF_NULL(root_graph); // Compile root graph. auto root_graph_id = CompileGraph(root_graph); // Compile sub graphs. FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); for (auto sub_graph : sub_graphs) { if (sub_graph != func_graph && sub_graph != nullptr) { (void)CompileGraph(sub_graph); } } // Transform graph to actor DAG, and schedule the actor DAG. std::vector graphs; std::vector device_contexts; for (const auto &graph_id_to_context : graph_to_device_context_) { graphs.emplace_back(runtime::GraphCompiler::GetInstance().Fetch(graph_id_to_context.first)); device_contexts.emplace_back(graph_id_to_context.second); } const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(graphs, device_contexts, nullptr, &control_nodes_); runtime::GraphScheduler::GetInstance().Schedule(actor_set); return root_graph_id; } GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(graph_partition_); // Split graph to segments. const auto &segments = graph_partition_->Partition(func_graph); MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); // Foreach the segments to compile graph. for (const auto &segment : segments) { MS_EXCEPTION_IF_NULL(segment); // Compile the normal nodes, which doesn't contain the cut node. 
if (!segment->is_cut_) { if (segment->nodes_.size() == 0) { MS_LOG(EXCEPTION) << "The segments size is 0."; } MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); // Get and set the device context. const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]); const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); device_context->Initialize(); runtime::GraphCompiler::GetInstance().set_device_context(device_context); // Transform nodes to inputs and outputs. FuncGraphPtr fg; AnfNodePtrList inputs; AnfNodePtrList outputs; std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); // Compile graph. auto graph_id = runtime::GraphCompiler::GetInstance().CompileGraph(segment->nodes_, outputs); graph_to_device_context_[graph_id] = device_context; } else { // Compile the cut node. auto cut_node = segment->nodes_[0]; MS_EXCEPTION_IF_NULL(cut_node); MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); control_nodes_.push_back(cut_node); } } return graph_to_device_context_.begin()->first; } VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { MS_LOG(INFO) << "Run graph begin, graph id: " << graph_id; const auto &context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); if (context_ptr->get_param(MS_CTX_PRECOMPILE_ONLY)) { MS_LOG(INFO) << "PrecompileOnly, stop run graph"; return VectorRef(); } // Fetch the kernel graph. const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id); MS_EXCEPTION_IF_NULL(kernel_graph); // Transform args to input tensors. 
std::vector inputs; for (const auto &input_node : kernel_graph->input_nodes()) { const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node); MS_EXCEPTION_IF_NULL(front_node); MS_EXCEPTION_IF_NULL(front_node->func_graph()); const auto &origin_parameters = front_node->func_graph()->parameters(); const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node); if (iter == origin_parameters.end()) { MS_LOG(EXCEPTION) << "Parameter node: " << front_node->fullname_with_scope() << " is not exist."; } auto position = IntToSize(std::distance(origin_parameters.begin(), iter)); PushInputTensor(args[position], &inputs); } // Fetch the actor DAG. const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph); MS_EXCEPTION_IF_NULL(actor_set); // Run actor DAG. mindspore::ScopedLongRunning long_running; VectorRef outputs; runtime::GraphScheduler::GetInstance().PrepareRun(kernel_graph, &inputs, &outputs); if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) { MS_LOG(EXCEPTION) << "The graph runs failed, graph id: " << graph_id << ", graph name: " << kernel_graph->ToString(); } MS_LOG(INFO) << "Run graph end, graph id: " << graph_id; return outputs; } } // namespace compile } // namespace mindspore