|
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "vm/backend.h"
-
- #include <algorithm>
- #include <vector>
-
- #include "backend/session/session_factory.h"
- #include "pipeline/pynative/pynative_execute.h"
- #include "ir/anf.h"
- #include "pybind_api/ir/base_ref_py.h"
- #include "utils/callbacks.h"
- #include "utils/convert_utils.h"
- #include "utils/log_adapter.h"
- #include "utils/ms_utils.h"
- #ifdef ENABLE_GE
- #include "utils/callbacks_ge.h"
- #endif
-
- namespace mindspore {
- namespace compile {
- // Reads a boolean condition out of `c` into `*value` by delegating to BaseRefToBool;
- // the helper's result is handed back unchanged.
- bool Backend::GetCond(const BaseRef &c, bool *const value) {
-   const bool converted = BaseRefToBool(c, value);
-   return converted;
- }
- // Reads an integer index out of `c` into `*value`: `c` is cast to a ValuePtr, which is
- // then passed to BaseRefToInt; the helper's result is handed back unchanged.
- bool Backend::GetIndex(const BaseRef &c, int64_t *const value) {
-   const auto index_value = utils::cast<ValuePtr>(c);
-   return BaseRefToInt(index_value, value);
- }
-
- // Creates a backend called `name`. The segment converter is set to MsVmConvert and
- // multi-graph sink starts out disabled.
- Backend::Backend(const std::string &name) : name_(name) {
-   MS_LOG(DEBUG) << "select backend:" << name;
-   is_multi_graph_sink_ = false;
-   convert_fn_ = MsVmConvert;
- }
-
- // Compiles `segment` into a session graph and returns a LinConvertResult describing it.
- //
- // segment: graph segment to compile; must not be null. Its graph_id_ is overwritten with
- //          the id of the compiled graph.
- // target:  device target of the segment. When it is non-empty and differs from
- //          target_device_, the segment is compiled in the secondary session (other_sess_),
- //          which is created on demand; otherwise the primary session (target_sess_) is used.
- //
- // Returns the entry from g_ConvertCache if `segment` has one. When MS_CTX_PRECOMPILE_ONLY
- // is set, the result is returned right after compilation with graph_id still
- // kInvalidGraphId and `run`/`simu_run` left unset. Otherwise the result carries the
- // graph id and both closures, and is stored in graph_id_map_.
- //
- // NOTE(review): the closures capture `this` and can end up in g_ConvertCache; this
- // presumably requires the backend to outlive those cache entries -- confirm.
- LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
-   MS_LOG(DEBUG) << "MsConvert";
-   MS_EXCEPTION_IF_NULL(segment);
-   auto ms_context = MsContext::GetInstance();
-   MS_EXCEPTION_IF_NULL(ms_context);
-   auto cached = g_ConvertCache.find(segment);
-   if (cached != g_ConvertCache.end()) {
-     return cached->second;
-   }
-   LinConvertResult result;
-   FuncGraphPtr fg;
-   AnfNodePtrList inputs;
-   AnfNodePtrList outputs;
-   std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
-   result.inputs = inputs;
-   result.outputs = outputs;
-   result.graph_id = kInvalidGraphId;
-   // A non-empty target that differs from the primary device selects the secondary session.
-   const bool use_other_session = (target != target_device_ && !target.empty());
-   auto current_session = target_sess_;
-   if (use_other_session) {
-     CreateOtherSession(target);
-     current_session = other_sess_;
-   }
-   MS_EXCEPTION_IF_NULL(current_session);
-   GraphId graph_id = current_session->CompileGraph(segment, outputs);
-   segment->graph_id_ = graph_id;
-   auto graph = current_session->GetGraph(graph_id);
-   MS_EXCEPTION_IF_NULL(graph);
-   // Link the new graph with the graph of every entry in pre_segments_, in both directions.
-   for (auto &pre_segment : segment->pre_segments_) {
-     MS_EXCEPTION_IF_NULL(pre_segment);
-     auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
-     // other_sess_ is only set once CreateOtherSession has run, so consult it only when it
-     // exists; if the graph is in neither session the null check below reports it instead
-     // of dereferencing a null session pointer.
-     if (pre_graph == nullptr && other_sess_ != nullptr) {
-       pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
-     }
-     MS_EXCEPTION_IF_NULL(pre_graph);
-     pre_graph->AddPostGraph(graph);
-     graph->AddPreGraph(pre_graph);
-     MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
-   }
-
-   if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
-     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
-     return result;
-   }
-   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
-   // BuildGraph is skipped here only for the combination pynative mode + "Ascend" target.
-   if (!pynative_mode || target != "Ascend") {
-     if (use_other_session) {
-       other_sess_->BuildGraph(graph_id);
-     } else if (!is_multi_graph_sink_) {
-       target_sess_->BuildGraph(graph_id);
-     }
-   }
-   result.run = std::make_shared<RunFunc>(
-     [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
-   MS_EXCEPTION_IF_NULL(result.run);
-
-   result.simu_run = std::make_shared<RunFunc>(
-     [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
-   MS_EXCEPTION_IF_NULL(result.simu_run);
-   result.graph_id = graph_id;
-
-   graph_id_map_[graph_id] = result;
-   // The result is added to the conversion cache only when the pynative executor does not
-   // report a dynamic cell.
-   if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
-     (void)g_ConvertCache.emplace(segment, result);
-   }
-   return result;
- }
-
- // Simulated run: executes nothing and returns the output nodes recorded for graph `g`
- // in graph_id_map_, wrapped in a VectorRef. `args` is not used.
- // The map is queried with find() so that an unknown graph id yields an empty VectorRef
- // without creating a placeholder entry in graph_id_map_.
- VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
-   MS_LOG(DEBUG) << "set graph input:" << g;
-   std::vector<BaseRef> outputs;
-   const auto iter = graph_id_map_.find(g);
-   if (iter != graph_id_map_.end()) {
-     const auto &output_nodes = iter->second.outputs;
-     outputs.reserve(output_nodes.size());
-     (void)std::transform(output_nodes.begin(), output_nodes.end(), std::back_inserter(outputs),
-                          [](const AnfNodePtr &v) { return v; });
-   }
-   return VectorRef(outputs);
- }
-
- namespace {
- // Flattens `arg` into tensors that are appended to `*inputs` (which must not be null):
- //   - a tensor is appended as is;
- //   - a ValueTuple contributes each of its elements, cast to a tensor pointer;
- //   - a Scalar is converted through ScalarToTensor;
- //   - a Monad is replaced by a placeholder bool tensor built from 0;
- //   - any other Value is appended after a cast to a tensor pointer;
- //   - a PyObjectRef is converted with py::cast;
- //   - a VectorRef is walked element by element, recursively.
- // Any other kind of argument is skipped with a warning.
- void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
-   MS_EXCEPTION_IF_NULL(inputs);
-   if (utils::isa<tensor::TensorPtr>(arg)) {
-     inputs->push_back(utils::cast<tensor::TensorPtr>(arg));
-     return;
-   }
-   if (utils::isa<ValuePtr>(arg)) {
-     auto val = utils::cast<ValuePtr>(arg);
-     MS_EXCEPTION_IF_NULL(val);
-     if (val->isa<ValueTuple>()) {
-       auto tuple_ptr = val->cast<ValueTuplePtr>();
-       MS_EXCEPTION_IF_NULL(tuple_ptr);
-       auto elements = tuple_ptr->value();
-       for (const ValuePtr &element : elements) {
-         inputs->push_back(element->cast<tensor::TensorPtr>());
-       }
-       return;
-     }
-     if (val->isa<Scalar>()) {
-       inputs->push_back(ScalarToTensor(val->cast<ScalarPtr>()));
-       return;
-     }
-     if (val->isa<Monad>()) {
-       // The monad itself carries no tensor data; an unused tensor stands in for it.
-       inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
-       return;
-     }
-     inputs->push_back(val->cast<tensor::TensorPtr>());
-     return;
-   }
-   if (utils::isa<PyObjectRef>(arg)) {
-     auto py_object = utils::cast<PyObjectRef>(arg).object_;
-     inputs->push_back(py::cast<tensor::TensorPtr>(py_object));
-     return;
-   }
-   if (utils::isa<VectorRefPtr>(arg)) {
-     const auto &nested_args = utils::cast<VectorRef>(arg);
-     for (const auto &nested : nested_args) {
-       PushInputTensor(nested, inputs);
-     }
-     return;
-   }
-   MS_LOG(WARNING) << "Invalid input type.";
- }
- } // namespace
-
- // Runs graph `g` with `args` and returns the outputs produced by the session.
- //
- // g:      id of the graph to run.
- // args:   call arguments; each one is flattened into tensors by PushInputTensor.
- // target: device target. A non-empty target that differs from target_device_ selects the
- //         secondary session (other_sess_); anything else selects the primary session.
- //
- // In pynative execution mode the graph is run through RunOpsInGraph, otherwise through
- // RunGraphAsync. An exception is raised if the selected session or the MsContext
- // instance is null.
- VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
-   MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
-   // Run graph
-   std::vector<tensor::TensorPtr> inputs;
-   for (const auto &arg : args) {
-     PushInputTensor(arg, &inputs);
-   }
-
-   VectorRef outputs;
-   // Call ms RunGraphAsync or RunOpsInGraph (graphId, input ,output)
-   const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
-   // other_sess_ stays null until CreateOtherSession has run, so validate before use.
-   MS_EXCEPTION_IF_NULL(exe_session);
-   auto ms_context = MsContext::GetInstance();
-   MS_EXCEPTION_IF_NULL(ms_context);
-   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
-   if (pynative_mode) {
-     exe_session->RunOpsInGraph(g, inputs, &outputs);
-   } else {
-     exe_session->RunGraphAsync(g, inputs, &outputs);
-   }
-
-   MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
-   return outputs;
- }
-
- // Builds a graph in the primary session. Passing kInvalidGraphId selects the graph
- // returned by the session's GetFinalRunGraph() instead.
- void MsBackend::Link(GraphId graph_id) {
-   const GraphId graph_to_build = (graph_id != kInvalidGraphId) ? graph_id : target_sess_->GetFinalRunGraph();
-   target_sess_->BuildGraph(graph_to_build);
- }
-
- // Creates a backend called `name` whose primary session runs on `target`.
- // The segment converter is bound to MsConvert, the session is created through
- // SessionFactory, initialised with `device_id` and given the summary-save callback.
- // An exception is logged when no session can be created for `target`.
- MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
-   using std::placeholders::_1;
-   using std::placeholders::_2;
-   convert_fn_ = std::bind(&MsBackend::MsConvert, this, _1, _2);
-
-   target_sess_ = session::SessionFactory::Get().Create(target);
-   if (target_sess_ == nullptr) {
-     MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
-   }
-
-   target_sess_->Init(device_id);
-   target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
-   target_device_ = target;
- }
-
- // Ensures the secondary session (other_sess_) exists for `target`.
- //
- // Nothing happens when a secondary session for the same target is already present.
- // Otherwise a session is created through SessionFactory, initialised with the device id
- // read from MsContext (MS_CTX_DEVICE_ID) and given the summary-save callback.
- // An exception is logged when no session can be created for `target`, and one is raised
- // when the MsContext instance is null.
- //
- // other_sess_ and other_device_ are only updated together, after the new session has been
- // fully set up. If any step fails, other_sess_ is left null, so a later call can never
- // match a stale other_device_ against a session that was not initialised for it.
- void MsBackend::CreateOtherSession(const std::string &target) {
-   if (other_sess_ != nullptr && other_device_ == target) {
-     return;
-   }
-   // Release the previous secondary session before building its replacement.
-   other_sess_ = nullptr;
-   auto new_session = session::SessionFactory::Get().Create(target);
-   if (new_session == nullptr) {
-     MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
-   }
-   auto context_ptr = MsContext::GetInstance();
-   MS_EXCEPTION_IF_NULL(context_ptr);
-   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-   new_session->Init(device_id);
-   new_session->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
-   // Publish the session and its device name as a pair.
-   other_sess_ = new_session;
-   other_device_ = target;
- }
-
- // Compiles the function graph `fg` in the primary session and returns the graph id the
- // session reports.
- GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) {
-   const GraphId compiled_id = target_sess_->CompileGraph(fg);
-   return compiled_id;
- }
-
- // Runs graph `graph_id` with `args` by forwarding to MsRunGraph; the target argument is
- // left at its declared default.
- VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
-   return MsRunGraph(graph_id, args);
- }
-
- // Asks the primary session to clear its graphs; does nothing when there is no primary
- // session.
- void MsBackend::ClearSessionGraphs() {
-   if (target_sess_ == nullptr) {
-     return;
-   }
-   target_sess_->ClearGraph();
- }
- #ifdef ENABLE_DEBUGGER
- // Forwards to SetDebugger() on the primary session.
- void MsBackend::SetDebugger() {
-   target_sess_->SetDebugger();
- }
- #endif
- } // namespace compile
- } // namespace mindspore
|