You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

backend.cc 8.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "vm/backend.h"
  17. #include <algorithm>
  18. #include <vector>
  19. #include "backend/session/session_factory.h"
  20. #include "pipeline/pynative/pynative_execute.h"
  21. #include "ir/anf.h"
  22. #include "pybind_api/ir/base_ref_py.h"
  23. #include "utils/callbacks.h"
  24. #include "utils/convert_utils.h"
  25. #include "utils/log_adapter.h"
  26. #include "utils/ms_utils.h"
  27. #ifdef ENABLE_GE
  28. #include "utils/callbacks_ge.h"
  29. #endif
  30. namespace mindspore {
  31. namespace compile {
  32. bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
  33. bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
  34. Backend::Backend(const std::string &name) : name_(name) {
  35. MS_LOG(DEBUG) << "select backend:" << name;
  36. convert_fn_ = MsVmConvert;
  37. is_multi_graph_sink_ = false;
  38. }
  39. LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  40. MS_LOG(DEBUG) << "MsConvert";
  41. MS_EXCEPTION_IF_NULL(segment);
  42. MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  43. auto cached = g_ConvertCache.find(segment);
  44. if (cached != g_ConvertCache.end()) {
  45. return cached->second;
  46. }
  47. LinConvertResult result;
  48. FuncGraphPtr fg;
  49. AnfNodePtrList inputs;
  50. AnfNodePtrList outputs;
  51. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  52. result.inputs = inputs;
  53. result.outputs = outputs;
  54. result.graph_id = kInvalidGraphId;
  55. auto current_session = target_sess_;
  56. if (target != target_device_ && !target.empty()) {
  57. CreateOtherSession(target);
  58. current_session = other_sess_;
  59. }
  60. MS_EXCEPTION_IF_NULL(current_session);
  61. GraphId graph_id = current_session->CompileGraph(segment, outputs);
  62. segment->graph_id_ = graph_id;
  63. auto graph = current_session->GetGraph(graph_id);
  64. MS_EXCEPTION_IF_NULL(graph);
  65. for (auto &pre_segment : segment->pre_segments_) {
  66. MS_EXCEPTION_IF_NULL(pre_segment);
  67. auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
  68. if (pre_graph == nullptr) {
  69. pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
  70. }
  71. MS_EXCEPTION_IF_NULL(pre_graph);
  72. pre_graph->AddPostGraph(graph);
  73. graph->AddPreGraph(pre_graph);
  74. MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  75. }
  76. if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
  77. MS_LOG(INFO) << "PrecompileOnly, stop run graph";
  78. return result;
  79. }
  80. auto ms_context = MsContext::GetInstance();
  81. const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  82. if (!pynative_mode || target != "Ascend") {
  83. if (target != target_device_ && !target.empty()) {
  84. other_sess_->BuildGraph(graph_id);
  85. } else if (!is_multi_graph_sink_) {
  86. target_sess_->BuildGraph(graph_id);
  87. }
  88. }
  89. result.run = std::make_shared<RunFunc>(
  90. [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  91. MS_EXCEPTION_IF_NULL(result.run);
  92. result.simu_run = std::make_shared<RunFunc>(
  93. [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
  94. MS_EXCEPTION_IF_NULL(result.simu_run);
  95. result.graph_id = graph_id;
  96. graph_id_map_[graph_id] = result;
  97. if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
  98. (void)g_ConvertCache.emplace(segment, result);
  99. }
  100. return result;
  101. }
  102. // compile set input output
  103. VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  104. MS_LOG(DEBUG) << "set graph input:" << g;
  105. std::vector<BaseRef> outputs;
  106. (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
  107. [](const AnfNodePtr &v) { return v; });
  108. return VectorRef(outputs);
  109. }
  110. namespace {
  111. void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
  112. MS_EXCEPTION_IF_NULL(inputs);
  113. if (utils::isa<tensor::TensorPtr>(arg)) {
  114. auto value = utils::cast<tensor::TensorPtr>(arg);
  115. inputs->push_back(value);
  116. } else if (utils::isa<ValuePtr>(arg)) {
  117. auto value = utils::cast<ValuePtr>(arg);
  118. MS_EXCEPTION_IF_NULL(value);
  119. if (value->isa<ValueTuple>()) {
  120. auto value_tuple = value->cast<ValueTuplePtr>();
  121. MS_EXCEPTION_IF_NULL(value_tuple);
  122. auto tuple_value = value_tuple->value();
  123. (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
  124. [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
  125. } else if (value->isa<Scalar>()) {
  126. tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
  127. inputs->push_back(scalar_tensor);
  128. } else if (value->isa<Monad>()) {
  129. // If value is a monad, replace it with an unused tensor.
  130. inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
  131. } else {
  132. inputs->push_back(value->cast<tensor::TensorPtr>());
  133. }
  134. } else if (utils::isa<PyObjectRef>(arg)) {
  135. auto value = utils::cast<PyObjectRef>(arg).object_;
  136. inputs->push_back(py::cast<tensor::TensorPtr>(value));
  137. } else if (utils::isa<VectorRefPtr>(arg)) {
  138. const auto &args_new = utils::cast<VectorRef>(arg);
  139. for (const auto &v : args_new) {
  140. PushInputTensor(v, inputs);
  141. }
  142. } else {
  143. MS_LOG(WARNING) << "Invalid input type.";
  144. }
  145. }
  146. } // namespace
  147. VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  148. MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
  149. // Run graph
  150. std::vector<tensor::TensorPtr> inputs;
  151. for (const auto &arg : args) {
  152. PushInputTensor(arg, &inputs);
  153. }
  154. VectorRef outputs;
  155. // Call ms RunGraphAsync or RunOpsInGraph (graphId, input ,output)
  156. const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
  157. auto ms_context = MsContext::GetInstance();
  158. const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  159. if (pynative_mode) {
  160. exe_session->RunOpsInGraph(g, inputs, &outputs);
  161. } else {
  162. exe_session->RunGraphAsync(g, inputs, &outputs);
  163. }
  164. MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  165. return outputs;
  166. }
  167. void MsBackend::Link(GraphId graph_id) {
  168. if (graph_id == kInvalidGraphId) {
  169. graph_id = target_sess_->GetFinalRunGraph();
  170. }
  171. target_sess_->BuildGraph(graph_id);
  172. }
  173. MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  174. convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  175. target_sess_ = session::SessionFactory::Get().Create(target);
  176. if (target_sess_ == nullptr) {
  177. MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  178. }
  179. target_sess_->Init(device_id);
  180. target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  181. target_device_ = target;
  182. }
  183. void MsBackend::CreateOtherSession(const std::string &target) {
  184. if (other_sess_ != nullptr && other_device_ == target) {
  185. return;
  186. }
  187. other_sess_ = session::SessionFactory::Get().Create(target);
  188. if (other_sess_ == nullptr) {
  189. MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  190. }
  191. auto context_ptr = MsContext::GetInstance();
  192. MS_EXCEPTION_IF_NULL(context_ptr);
  193. uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  194. other_sess_->Init(device_id);
  195. other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  196. other_device_ = target;
  197. }
  198. GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
  199. VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
  200. void MsBackend::ClearSessionGraphs() {
  201. if (target_sess_ != nullptr) {
  202. target_sess_->ClearGraph();
  203. }
  204. }
  205. #ifdef ENABLE_DEBUGGER
  206. void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
  207. #endif
  208. } // namespace compile
  209. } // namespace mindspore