/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "vm/backend.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "utils/log_adapter.h"
#include "ir/anf.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "backend/session/session_factory.h"
#include "utils/ms_utils.h"
#include "pybind_api/ir/base_ref_py.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif

namespace mindspore {
namespace compile {
// Converts the condition reference `c` to a bool written to `*value`.
// Returns whatever BaseRefToBool reports.
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }

// Casts the reference `c` to a value and converts it to an integer written to `*value`.
// Returns whatever BaseRefToInt reports.
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

// Base backend: records the backend name, installs MsVmConvert as the conversion
// function and starts with multi-graph sink disabled.
Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "select backend:" << name;
  convert_fn_ = MsVmConvert;
  is_multi_graph_sink_ = false;
}

// Compiles one graph segment into a session graph and returns the callables used to run it.
//
// segment: the segment to compile; must not be null.
// target:  device target for this segment. When it is non-empty and differs from
//          target_device_, the segment is compiled in the secondary session (other_sess_).
//
// Results are cached per segment in g_ConvertCache, so a segment is compiled at most once.
// When MS_CTX_PRECOMPILE_ONLY is set, the function returns right after compilation and
// graph linking, with result.graph_id left as kInvalidGraphId and no run callables.
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  auto cached = g_ConvertCache.find(segment);
  if (cached != g_ConvertCache.end()) {
    return cached->second;
  }

  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;

  // A non-empty target that differs from the primary device is handled by the secondary session.
  const bool use_other_session = (target != target_device_) && !target.empty();
  auto current_session = target_sess_;
  if (use_other_session) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);

  const GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);

  // Link the new graph with the graphs of all preceding segments.
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    // Only consult the secondary session if one has actually been created; otherwise
    // fall through to the null check below instead of dereferencing a null session.
    if (pre_graph == nullptr && other_sess_ != nullptr) {
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }

  // Secondary-session graphs are always built here; primary-session graphs are built
  // here only when multi-graph sink is disabled. current_session is other_sess_ in the
  // first case and target_sess_ in the second.
  if (use_other_session || !is_multi_graph_sink_) {
    current_session->BuildGraph(graph_id);
  }

  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);

  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
  MS_EXCEPTION_IF_NULL(result.simu_run);

  result.graph_id = graph_id;
  graph_id_map_[graph_id] = result;
  (void)g_ConvertCache.emplace(segment, result);
  return result;
}

// compile set input output
// Returns the output nodes recorded for graph `g` wrapped in a VectorRef; `args` is not used.
// An id that was never registered in graph_id_map_ yields an empty VectorRef and does not
// add an entry to the map.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "set graph input:" << g;
  std::vector<BaseRef> outputs;
  const auto iter = graph_id_map_.find(g);
  if (iter != graph_id_map_.end()) {
    const auto &graph_outputs = iter->second.outputs;
    (void)std::transform(graph_outputs.begin(), graph_outputs.end(), std::back_inserter(outputs),
                         [](const AnfNodePtr &v) { return v; });
  }
  return VectorRef(outputs);
}

namespace {
// Flattens one run argument into `inputs`:
//  - a tensor is appended as is;
//  - a value tuple contributes one entry per element (each element cast to a tensor);
//  - a scalar value is converted with ScalarToTensor;
//  - any other value is cast to a tensor;
//  - a Python object reference is cast to a tensor through pybind;
//  - a nested VectorRef is flattened recursively.
// Any other kind of argument is skipped with a warning.
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(inputs);
  if (utils::isa<tensor::TensorPtr>(arg)) {
    auto value = utils::cast<tensor::TensorPtr>(arg);
    inputs->push_back(value);
  } else if (utils::isa<ValuePtr>(arg)) {
    auto value = utils::cast<ValuePtr>(arg);
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      auto value_tuple = value->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple);
      auto tuple_value = value_tuple->value();
      (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
                           [](const ValuePtr &v) {
                             // Raise a framework exception rather than dereference a null element.
                             MS_EXCEPTION_IF_NULL(v);
                             return v->cast<tensor::TensorPtr>();
                           });
    } else if (value->isa<Scalar>()) {
      tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
      inputs->push_back(scalar_tensor);
    } else {
      inputs->push_back(value->cast<tensor::TensorPtr>());
    }
  } else if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    inputs->push_back(py::cast<tensor::TensorPtr>(value));
  } else if (utils::isa<VectorRefPtr>(arg)) {
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &v : args_new) {
      PushInputTensor(v, inputs);
    }
  } else {
    MS_LOG(WARNING) << "Invalid input type.";
  }
}
}  // namespace

// Runs graph `g` with `args` flattened into a tensor list and returns the outputs filled
// in by the session. The secondary session is used when `target` is non-empty and differs
// from target_device_; otherwise the primary session is used.
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    PushInputTensor(arg, &inputs);
  }

  VectorRef outputs;
  // call ms rungraph (graphId, input ,output)
  if (target != target_device_ && !target.empty()) {
    // The secondary session only exists once CreateOtherSession has been called.
    MS_EXCEPTION_IF_NULL(other_sess_);
    other_sess_->RunGraphAsync(g, inputs, &outputs);
  } else {
    target_sess_->RunGraphAsync(g, inputs, &outputs);
  }
  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}

// Builds `graph_id` in the primary session. kInvalidGraphId selects the session's
// final run graph instead.
void MsBackend::Link(GraphId graph_id) {
  if (graph_id == kInvalidGraphId) {
    graph_id = target_sess_->GetFinalRunGraph();
  }
  target_sess_->BuildGraph(graph_id);
}

// Creates the backend `name` with a primary session for device `target`, initialised on
// `device_id`. Raises an exception if the session factory cannot create a session for `target`.
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  // Route segment conversion through this instance's MsConvert.
  convert_fn_ = [this](const GraphSegmentPtr &segment, const std::string &segment_target) {
    return MsConvert(segment, segment_target);
  };
  target_sess_ = session::SessionFactory::Get().Create(target);
  if (target_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  }
  target_sess_->Init(device_id);
  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  target_device_ = target;
}

// Ensures the secondary session targets `target`. An existing secondary session for the
// same target is reused; otherwise a new one is created, initialised with the device id
// from the context, and replaces the previous secondary session.
// Raises an exception if the session cannot be created.
void MsBackend::CreateOtherSession(const std::string &target) {
  if (other_sess_ != nullptr && other_device_ == target) {
    return;
  }
  other_sess_ = session::SessionFactory::Get().Create(target);
  if (other_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  other_sess_->Init(device_id);
  other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  other_device_ = target;
}

// Compiles the whole function graph `fg` in the primary session and returns its graph id.
GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }

// Runs `graph_id` with `args` through MsRunGraph, leaving its target argument at its default.
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

#ifdef ENABLE_DEBUGGER
// Forwards to the primary session's SetDebugger.
void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
#endif
}  // namespace compile
}  // namespace mindspore