
backend.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "vm/backend.h"

#include <algorithm>
#include <vector>
#include <map>
#include <cxxabi.h>  // for abi::__cxa_current_exception_type(), used in LazyExecuteTaskCallback

#include "vm/transform.h"
#include "backend/session/session_factory.h"
#include "runtime/op_builder/op_lazy_builder.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/parse/data_converter.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "pybind_api/pybind_patch.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif

namespace mindspore {
namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) {
  mindspore::ScopedLongRunning long_running;
  return BaseRefToBool(c, value);
}
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "Select backend:" << name;
  convert_fn_ = MsVmConvert;
  is_multi_graph_sink_ = false;
}
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto current_session = target_sess_;
  if (target != target_device_ && !target.empty()) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);
  GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    MS_EXCEPTION_IF_NULL(target_sess_);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    if (pre_graph == nullptr) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (!pynative_mode || target != "Ascend") {
    if (target != target_device_ && !target.empty()) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      other_sess_->BuildGraph(graph_id);
    } else if (!is_multi_graph_sink_) {
      MS_EXCEPTION_IF_NULL(target_sess_);
      target_sess_->BuildGraph(graph_id);
    }
  }
  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);
  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id); });
  MS_EXCEPTION_IF_NULL(result.simu_run);
  result.graph_id = graph_id;
  graph_id_map_[graph_id] = result;
  return result;
}
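
// Illustrative usage sketch (hypothetical caller and names, for exposition only, not part of this file's API):
// a caller gets a LinConvertResult from MsConvert and invokes its `run` functor to execute the compiled segment.
//
//   compile::MsBackend backend("ms", "GPU", 0);              // backend name, target device, device id
//   LinConvertResult res = backend.MsConvert(segment, "GPU");  // `segment` is an assumed GraphSegmentPtr
//   VectorRef out = (*res.run)(args);                          // executes kernel graph res.graph_id on `args`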

// Compile phase: only record the graph inputs and outputs, without running the graph.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
  MS_LOG(DEBUG) << "Set graph input:" << g;
  std::vector<BaseRef> outputs;
  (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &v) { return v; });
  return VectorRef(outputs);
}

namespace {
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(inputs);
  if (utils::isa<tensor::TensorPtr>(arg)) {
    auto value = utils::cast<tensor::TensorPtr>(arg);
    inputs->push_back(value);
  } else if (utils::isa<ValuePtr>(arg)) {
    auto value = utils::cast<ValuePtr>(arg);
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      auto value_tuple = value->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple);
      auto tuple_value = value_tuple->value();
      (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
                           [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
    } else if (value->isa<Scalar>()) {
      tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
      inputs->push_back(scalar_tensor);
    } else if (value->isa<Monad>()) {
      // If value is a monad, replace it with an unused tensor.
      inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
    } else {
      inputs->push_back(value->cast<tensor::TensorPtr>());
    }
  } else if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    inputs->push_back(py::cast<tensor::TensorPtr>(value));
  } else if (utils::isa<VectorRefPtr>(arg)) {
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &v : args_new) {
      PushInputTensor(v, inputs);
    }
  } else {
    MS_LOG(WARNING) << "Invalid input type.";
  }
}
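
// Descriptive note: PushInputTensor above accepts a BaseRef that is a Tensor, a ValuePtr (tuple, scalar, or
// monad), a Python object, or a nested VectorRef, and flattens everything into one linear tensor list.
// Monads carry no data, so they are replaced with a placeholder bool tensor to keep input positions aligned.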

// Insert the tensor related to the front_node into input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    (void)((*input_tensor).emplace_back(nullptr));
    return;
  }
  auto position = iter - parameters.begin();
  PushInputTensor(args[position], input_tensor);
}

void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &kernels = kernel_graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
      op_run_info->abstract = kernel->abstract();
    }
  }
}

TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(output_node);
  // Create the host tensor with the infer type; it will be handled correctly by tensor data sync
  // when the infer type is not equal to the device type.
  auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
  std::vector<int64_t> temp_shape;
  const auto &shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));
  // Put the device tensor into the host tensor.
  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  tensor->set_device_address(device_tensor);
  // MindRT is disabled in the multi-graph scenario.
  // Delete tensor->data_sync() when MindRT is enabled in all scenes.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    // If the execution mode in MsContext is graph mode, the tensor is the input of a graph that will execute in
    // graph mode; if the graph contains no CNode after optimization, the tensor needs to sync to host.
    tensor->data_sync(false);
  }
  return tensor;
}

void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  for (auto &item_with_index : output_nodes) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    // If the graph returns nothing, the function should return an empty VectorRef.
    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
      continue;
    }
    outputs->emplace_back(CreateOutputTensor(item_with_index.first, item_with_index.second));
  }
}

void UpdateOutputDeviceAddress(const std::vector<session::KernelWithIndex> &output_nodes,
                               const DeviceContext *device_context) {
  for (auto &item_with_index : output_nodes) {
    auto &output_node = item_with_index.first;
    auto output_index = item_with_index.second;
    if (output_node != nullptr) {
      if (!AnfAlgo::OutputAddrExist(output_node, output_index, false)) {
        continue;
      }
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
      if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
        continue;
      }
      MS_EXCEPTION_IF_NULL(device_context);
      auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                   device_tensor->format(), device_tensor->type_id());
      MS_EXCEPTION_IF_NULL(new_device_tensor);
      new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
      new_device_tensor->ResetRefCount();
      AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get());
    }
  }
}

void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
      AnfAlgo::SetOutputAddr(nullptr, 0, node.get());
    }
  }
}
}  // namespace

VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  MS_LOG(DEBUG) << "Start ms graph run:" << args.size() << ", g:" << g;
  // Run graph: convert the args to input tensors first.
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    PushInputTensor(arg, &inputs);
  }
  VectorRef outputs;
  // Call RunGraphAsync or RunOpsInGraph (graph_id, input, output).
  const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
  MS_EXCEPTION_IF_NULL(exe_session);
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (pynative_mode) {
    exe_session->RunOpsInGraph(g, inputs, &outputs);
  } else {
    exe_session->RunGraphAsync(g, inputs, &outputs);
  }
  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}
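
// Descriptive note: MsRunGraph is the single execution entry point for both modes. PyNative mode lowers to
// per-op execution via RunOpsInGraph, while graph mode submits the whole kernel graph via RunGraphAsync.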

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  target_sess_ = session::SessionFactory::Get().Create(target);
  if (target_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  target_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  target_device_ = target;
}

void MsBackend::CreateOtherSession(const std::string &target) {
  if (other_sess_ != nullptr && other_device_ == target) {
    return;
  }
  other_sess_ = session::SessionFactory::Get().Create(target);
  if (other_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  other_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  other_device_ = target;
}

GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) {
  MS_EXCEPTION_IF_NULL(target_sess_);
  return target_sess_->CompileGraph(fg);
}

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

void MsBackend::ClearSessionGraphs() {
  if (target_sess_ != nullptr) {
    target_sess_->ClearGraph();
  }
}

#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() {
  MS_EXCEPTION_IF_NULL(target_sess_);
  target_sess_->SetDebugger();
}
#endif

MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
    : Backend(backend_name), device_name_(device_name) {
  root_graph_ = nullptr;
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto &cut_list = pynative_mode ? compile::control_ops : GetMsNonlinearOps();
  graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
  graph_compiler_ = std::make_shared<GraphCompiler>();
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  device_context->Initialize();
  device_id_ = device_context->device_context_key().device_id_;
#ifdef ENABLE_DEBUGGER
  SetDebuggerInit();
#endif
  runtime::GraphScheduler::GetInstance().Initialize();
}

const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "Status record: start compile function graph: " << func_graph->ToString();
  auto root_graph = WrapPrimitives(func_graph);
  MS_EXCEPTION_IF_NULL(root_graph);
  root_graph_ = root_graph.get();
  // Register a summary callback function, which is called in the final stages of summary.
  graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  real_execution_mode_ = ms_execution_mode_;
  // Compile the root graph.
  graph_id_to_device_context_.clear();
  func_graph_to_kernel_graph_ids_.clear();
  control_nodes_.clear();
  auto subgraph_need_compile = CompileGraph(root_graph);
  // Compile the sub graphs.
  if (subgraph_need_compile) {
    MS_EXCEPTION_IF_NULL(root_graph->manager());
    FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
    for (auto sub_graph : sub_graphs) {
      if (sub_graph != func_graph && sub_graph != nullptr) {
        (void)CompileGraph(sub_graph);
      }
    }
  }
  // Construct the graph compiler info.
  auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
  if (real_execution_mode_ == kGraphMode) {
    // Transform the graph to an actor DAG, and schedule the actor DAG.
    const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
    runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  }
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  const ActorInfo &actor_info = graph_compiler_info->name_;
  (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
  MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
               << ", produce actor: " << actor_info;
  return actor_info;
}

bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(graph_partition_);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  bool contain_multi_target = false;
  // Split the graph into segments.
  const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target);
  MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &new_segments = device_context->PartitionGraph(func_graph, segments);
  // Compile the whole function graph if the graph was not split.
  if (new_segments.size() == 0) {
    auto graph_id = graph_compiler_->CompileGraph(func_graph, device_context);
    graph_id_to_device_context_[graph_id] = device_context;
    return false;
  }
  // Compile each segment.
  for (const auto &segment : new_segments) {
    CompileGraph(segment, contain_multi_target);
  }
  return true;
}

void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target) {
  MS_EXCEPTION_IF_NULL(segment);
  // Compile the normal nodes, which don't contain the cut node.
  if (segment->nodes_.size() == 0) {
    MS_LOG(EXCEPTION) << "The segments size is 0.";
  }
  if (!segment->is_cut_) {
    MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
    MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->DebugString();
    // Get the device context.
    const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
    const auto &device_context =
      device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
    MS_EXCEPTION_IF_NULL(device_context);
    device_context->Initialize();
    // Transform the nodes into inputs and outputs.
    FuncGraphPtr fg;
    AnfNodePtrList inputs;
    AnfNodePtrList outputs;
    std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    // There can be more than one kernel graph in the heterogeneous scenario in a ms_function of PyNative mode.
    if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
      real_execution_mode_ = kGraphMode;
      context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
    }
    // Compile the graph.
    auto graph_id = graph_compiler_->CompileGraph(segment->nodes_, outputs, device_context);
    if (ms_execution_mode_ != real_execution_mode_) {
      context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
    }
    graph_id_to_device_context_[graph_id] = device_context;
    const auto &func_graph = segment->nodes_[0]->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    func_graph_to_kernel_graph_ids_[func_graph].emplace_back(graph_id);
  } else {
    // Record the cut node as a control node.
    auto cut_node = segment->nodes_[0];
    MS_EXCEPTION_IF_NULL(cut_node);
    MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
    control_nodes_.push_back(cut_node);
  }
}
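
// Descriptive note: compilation is a three-level pipeline. CompileGraphs walks the root FuncGraph and all
// reachable sub graphs; CompileGraph(FuncGraphPtr) partitions each one into per-target segments; and
// CompileGraph(GraphSegmentPtr) turns every non-cut segment into a kernel graph, while cut nodes (control
// flow) are collected into control_nodes_ for the ControlNodeParser.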

namespace {
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
                       const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                       const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                       VectorRef *args) {
  MS_EXCEPTION_IF_NULL(front_cnode);
  MS_EXCEPTION_IF_NULL(backend_cnode);
  MS_EXCEPTION_IF_NULL(graph_compiler);
  MS_EXCEPTION_IF_NULL(args);
  size_t input_index = 0;
  auto inputs = front_cnode->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    const auto &input_node = inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    if (!real_input->isa<ValueNode>()) {
      TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
                                                                       graph_inputs, input_tensor_info, input_index);
      MS_EXCEPTION_IF_NULL(tensor);
      args->emplace_back(tensor);
      input_index++;
      continue;
    }
    // Get the value from the value node.
    const auto &value_node = real_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueSequeue>()) {
      const auto &value_sequeue = value->cast<ValueSequeuePtr>();
      MS_EXCEPTION_IF_NULL(value_sequeue);
      input_index += value_sequeue->size();
    } else {
      input_index++;
    }
    args->emplace_back(value);
  }
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    (void)tensors->emplace_back(tensor);
  }
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  (void)tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto inputs = py::cast<py::tuple>(input_object);
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(inputs[0])) {
    PlantTensorTupleToVector(inputs, tensors);
  } else {
    ConvertValueTupleToTensor(input_object, tensors);
  }
}

void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, const KernelGraphPtr &graph,
                        const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                        const std::map<AnfNodePtr, size_t> &parameter_index,
                        const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                        VectorRef *op_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_outputs);
  AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The front node of bprop_cut is not a CNode";
  }
  CNodePtr cnode = front_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
  if (node_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
  }
  const AnfNodePtr &fn = node_inputs.at(0);
  if (!IsValueNode<Primitive>(fn)) {
    MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
                      << "] is not a ValueNode of Primitive";
  }
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == kBpropCutOpName) {
    VectorRef args;
    GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info,
                      &args);
    auto py_prim = prim->cast<PrimitivePyPtr>();
    MS_EXCEPTION_IF_NULL(py_prim);
    BaseRef out = py_prim->RunHookFunction(args);
    // Convert the pyobject output to tensors.
    if (utils::isa<PyObjectRef>(out)) {
      PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
      auto out_py_tuple = py_ref.object_;
      std::vector<tensor::TensorPtr> output_tensors;
      ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
      (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
                           [](tensor::TensorPtr &tensor) { return std::move(tensor); });
    }
  }
}
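
// Descriptive note: RunControlOperator handles ops such as bprop_cut that must run in the Python frontend.
// It maps the backend kernel to its front CNode, gathers the op inputs with GetControlOpInput, invokes the
// registered Python hook via PrimitivePy::RunHookFunction, and converts the returned py::object back into
// output tensors with ConvertMultiPyObjectToTensor.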

void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(outputs);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    for (size_t i = 0; i < value_tuple->size(); ++i) {
      ValuePtr element = value_tuple->value()[i];
      MS_EXCEPTION_IF_NULL(element);
      if (element->isa<tensor::Tensor>()) {
        auto tensor = element->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        outputs->emplace_back(tensor);
      } else if (element->isa<ValueTuple>()) {
        TensorValueToVector(element, outputs);
      }
    }
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    outputs->emplace_back(tensor);
  }
}

bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(graph_output);
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    VectorRef output_tmp;
    ValuePtr value = GetValueNode(graph_output);
    TensorValueToVector(value, &output_tmp);
    if (output_tmp.size() == 1) {
      *outputs = std::move(output_tmp);
    } else if (output_tmp.size() > 1) {
      outputs->emplace_back(output_tmp);
    } else {
      MS_LOG(EXCEPTION) << "Output is empty!";
    }
    return true;
  }
  if (graph_output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = graph_output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (args.size() != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to graph input size " << params.size();
    }
    auto it = std::find(params.begin(), params.end(), graph_output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When the graph output is a Parameter, it should be found in the graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
    }
    outputs->emplace_back(args[index]);
    return true;
  }
  return false;
}
}  // namespace

void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
  for (size_t i = 0; i < value.size(); ++i) {
    auto value_element = value[i];
    MS_EXCEPTION_IF_NULL(value_element);
    if (utils::isa<tensor::TensorPtr>(value_element)) {
      flatted_value->emplace_back(value_element);
    } else if (utils::isa<ValueTuplePtr>(value_element)) {
      auto value_tuple_element = value_element->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple_element);
      FlatValueTupleValue(value_tuple_element->value(), flatted_value);
    } else {
      MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contain Tensor and ValueTuple.";
    }
  }
}
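
// Illustrative example: FlatValueTupleValue flattens nested ValueTuples depth-first, so a value shaped like
// ((t1, t2), t3) becomes the flat list (t1, t2, t3). Only Tensor leaves and ValueTuple interior nodes are
// accepted; anything else raises an exception.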

void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                     size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  const size_t position = iter - parameters.begin();
  // If the parameter is not found in the parameters of the root graph, it is an input of a subgraph,
  // and there is no need to input a tensor.
  if (position >= args.size()) {
    MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
                 << ".";
    return;
  }
  auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
  MS_EXCEPTION_IF_NULL(value_tuple);
  auto value_tuple_value = value_tuple->value();
  ValuePtrList flatted_value_tuple_value;
  FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
  if (index >= flatted_value_tuple_value.size()) {
    MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
                      << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
  }
  auto input = flatted_value_tuple_value[index];
  MS_EXCEPTION_IF_NULL(input);
  auto tensor_input = input->cast<tensor::TensorPtr>();
  input_tensor->push_back(tensor_input);
}

void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
  runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks();
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    MS_EXCEPTION_IF_NULL(graph);
    std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
    std::map<AnfNodePtr, size_t> parameter_index;
    GraphOutputInfo graph_output_info;
    graph_output_info.graph_outputs = outputs;
    graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
                                            &graph_output_info.output_indexes);
    std::map<KernelWithIndex, size_t> cnode_ref_count;
    auto iter = cnode_ref_counts_.find(graph->graph_id());
    if (iter == cnode_ref_counts_.end()) {
      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
      (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
    } else {
      cnode_ref_count = iter->second;
    }
    // Clear bucket resources every step.
    if (graph->is_bprop()) {
      graph_compiler_->ClearAllBucket(graph->graph_id());
    }
    for (const auto &kernel : graph->execution_order()) {
      InputTensorInfo input_tensor_info;
      VectorRef op_outputs;
      if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
        OpRunInfo op_run_info;
        GraphInfo graph_info;
        graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
                                                 &input_tensor_info);
        graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info);
        RunOp(&op_run_info, &op_outputs);
      } else {
        RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index],
                           &input_tensor_info, &op_outputs);
        // Execute the remaining lazy tasks before the PyNative hook exits.
        runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks();
      }
      graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
      graph_output_info.graph_output_tensors.clear();
      graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
      // Save the grad node to a Bucket.
      if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
        graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
      }
    }
  }
}
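
// Descriptive note: in PyNative mode a kernel graph is not handed to the actor runtime as a whole.
// RunGraphBySingleOp replays the graph's execution_order() kernel by kernel, feeding each op through RunOp
// (or RunControlOperator for control ops) and recovering graph outputs with ref-count bookkeeping, so that
// intermediate device memory can be released as soon as its last consumer has run.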

void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(root_graph_);
  if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
    return;
  }
  const auto &context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return;
  }
  MS_LOG(INFO) << "Status record: start run actor: " << actor_info;
  // Fetch the graph compiler info.
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
  // Transform args to input tensors.
  // Input tensors of the graph.
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    for (const auto &input_node : kernel_graph->input_nodes()) {
      auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
      if (element_pair.first) {
        PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
      } else {
        const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
        PushTensor(args, origin_parameters, front_node, &input_tensor);
      }
    }
    (void)input_tensors.emplace_back(input_tensor);
  }
  // Input tensors of the control node.
  std::vector<tensor::TensorPtr> input_tensor;
  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  // Get the inputs of the control node which come from the host actor.
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
  (void)input_tensors.emplace_back(input_tensor);
  // Run in the pynative mode.
  MS_EXCEPTION_IF_NULL(outputs);
  // There can be more than one kernel graph in the heterogeneous scenario in a ms_function of PyNative mode.
  if (real_execution_mode_ == kPynativeMode) {
    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
    MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
    return;
  }
  // Run the actor DAG.
  mindspore::ScopedLongRunning long_running;
  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, input_tensors);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->Summary(graph_compiler_info.graphs_);
  // Update the device address for the output nodes of the graph. Summary processing uses the output device
  // address, so this must happen after the summary processing. The null check is done before the first use.
  MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
  actor_set->output_actor_->UpdateOutputDeviceAddress();
  // Fetch outputs.
  auto &output_tensors = actor_set->output_actor_->outputs();
  if (output_tensors.size() > 0) {
    size_t output_position = 0;
    ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
  }
  runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
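
// Descriptive note: the graph-mode path above is: flatten args into per-graph input tensors, hand them to the
// scheduled actor DAG via GraphScheduler::Run, process summaries, refresh the output device addresses, and
// finally rebuild the nested Python-visible output structure with ConstructOutputs.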

void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
                                     const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
                                     VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_position);
  const PrimitiveSet expand_prims{
    prim::kPrimMakeTuple,
    prim::kPrimMakeCSRTensor,
    prim::kPrimMakeSparseTensor,
    prim::kPrimMakeRowTensor,
  };
  // The MakeTuple/MakeSparse nodes need to be expanded recursively.
  if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
    auto make_tuple = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    VectorRef make_tuple_output;
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output);
    }
    outputs->emplace_back(std::move(make_tuple_output));
    return;
  }
  // For a Depend node, recurse into its real input.
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
    auto depend_node = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_node);
    ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
    return;
  }
  auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node);
  // A value node outputs its value directly, to avoid freeing the value's host memory when the value node is
  // destructed.
  if (output_node->isa<ValueNode>()) {
    auto value = output_node->cast<ValueNodePtr>()->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      outputs->emplace_back(value);
      (*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
    } else if (outputs_num != 0) {
      outputs->emplace_back(value);
      (*output_position) += outputs_num;
    }
    // An empty value node returns an empty VectorRef.
    return;
  }
  auto &output_abstract = output_node->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  // Wrap the output in a VectorRef if the output is a tuple.
  if (output_abstract->isa<abstract::AbstractTuple>()) {
    VectorRef output_tuple;
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      output_tuple.emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
    outputs->emplace_back(std::move(output_tuple));
  } else {
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      outputs->emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
  }
}
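
// Illustrative example: for a root output MakeTuple(t1, MakeTuple(t2, t3)), ConstructOutputs consumes the
// flat output_tensors list [t1, t2, t3] in order and rebuilds the nested structure VectorRef{t1,
// VectorRef{t2, t3}}, advancing *output_position past each consumed tensor.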

#ifdef ENABLE_DEBUGGER
void MindRTBackend::SetDebuggerInit() {
  auto debugger_ = Debugger::GetInstance();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif

void MindRTBackend::SyncLazyTasks() const { runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks(); }

void MindRTBackend::ClearOpBuilderResource() const { runtime::OpLazyBuilder::GetInstance().Reset(); }

void MindRTBackend::SyncStream() {
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  (void)device_context->SyncStream();
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
  MS_EXCEPTION_IF_NULL(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  std::string name = "kernel_graph";
  for (const auto &graph_id_to_context : graph_id_to_device_context_) {
    (void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
    (void)device_contexts.emplace_back(graph_id_to_context.second);
    (void)name.append("_").append(std::to_string(graph_id_to_context.first));
  }
  FuncGraphToKernelGraph func_graph_to_kernel_graphs;
  for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
    const auto &func_graph = func_graph_to_kernel_graph_ids.first;
    for (const auto &graph_id : func_graph_to_kernel_graph_ids.second) {
      const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
      MS_EXCEPTION_IF_NULL(kernel_graph);
      func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graph);
    }
  }
  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
  runtime::KernelMapPosition outputs_order;
  size_t outputs_num = 0;
  const auto &root_output =
    AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
  size_t position = 0;
  auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
  outputs_num = outputs.size();
  for (const auto &output : outputs) {
    if (outputs_order.count(output) == 0) {
      outputs_order[output] = {position++};
    } else {
      (void)outputs_order[output].emplace_back(position++);
    }
  }
  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
                                             root_graph->parameters(), parser, outputs_order, outputs_num, name, false,
                                             runtime::GraphExecutionStrategy::kPipeline);
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,
  const std::vector<tensor::TensorPtr> *input_tensors, bool need_erase) {
  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  runtime::KernelMapPosition outputs_order;
  size_t position = 0;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (const auto &graph_info_to_context : graph_info_to_device_context_) {
    const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
    MS_EXCEPTION_IF_NULL(graph);
    (void)graphs.emplace_back(graph);
    (void)device_contexts.emplace_back(graph_info_to_context.second);
    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output : outputs) {
      if (outputs_order.count(output) == 0) {
        outputs_order[output] = {position++};
      } else {
        (void)outputs_order[output].emplace_back(position++);
      }
    }
  }
  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
  std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
                                                           const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
  auto parser = std::make_shared<ControlNodeParser>();
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
                                             outputs_order, 0, actor_info, need_erase,
                                             runtime::GraphExecutionStrategy::kStep);
}

void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (graph_info_to_device_context_.empty()) {
    MS_LOG(EXCEPTION) << "The map graph_info_to_device_context_ is empty.";
  }
  const auto &graph_info = graph_info_to_device_context_.begin()->first;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id());
  actor_to_graph_compiler_info_.erase(actor_info);
}

void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph,
                                     const std::vector<session::KernelWithIndex> &output_nodes,
                                     const OpRunInfo &op_run_info, const GraphCompilerInfo *graph_compiler_info,
                                     DeviceContext *device_context) {
  // Erase the value node tensors.
  std::vector<tensor::TensorPtr> tensors_without_value_node;
  const auto &input_tensors = op_run_info.input_tensors;
  const auto &tensors_mask = op_run_info.tensor_mask;
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t index = 0; index < tensors_mask.size(); ++index) {
    if (tensors_mask.at(index) != kValueNodeTensorMask) {
      (void)tensors_without_value_node.emplace_back(input_tensors.at(index));
    }
  }
  for (auto &tensor : tensors_without_value_node) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (tensor->NeedWaitDevice()) {
      tensor->WaitDevice();
    }
  }
  // Run the actor DAG.
  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(graph_compiler_info->name_);
  MS_EXCEPTION_IF_NULL(actor_set);
  runtime::GraphScheduler::GetInstance().Run(actor_set, {}, {tensors_without_value_node}, input_tensors,
                                             runtime::GraphExecutionStrategy::kStep);
  // Release the kernel resources.
  const auto &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (kOpCacheBlackList.find(AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      if (kernel_mod) {
        kernel_mod->ReleaseResource();
      }
    }
  }
}

void MindRTBackend::CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpTask>> &build_tasks) {
  if (build_tasks.empty()) {
    return;
  }
  std::vector<KernelGraphPtr> graphs;
  std::vector<GraphCompilerInfo *> graph_compiler_infos;
  for (const auto &task : build_tasks) {
    MS_EXCEPTION_IF_NULL(task);
    const auto &context = task->context();
    MS_EXCEPTION_IF_NULL(context);
    graphs.push_back(context->graph());
    graph_compiler_infos.push_back(context->graph_compiler_info());
  }
  MS_EXCEPTION_IF_NULL(build_tasks[0]);
  auto &task_context = build_tasks[0]->context();
  MS_EXCEPTION_IF_NULL(task_context);
  auto device_context = task_context->device_context();
  graph_compiler_->BuildSingleOpGraphs(graphs, device_context);
  for (const auto &graph_compiler_info : graph_compiler_infos) {
    MS_EXCEPTION_IF_NULL(graph_compiler_info);
    auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
    graph_compiler_info->input_tensors_.clear();
    runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  }
}

void MindRTBackend::LazyExecuteTaskCallback() {
  auto &op_lazy_builder = runtime::OpLazyBuilder::GetInstance();
  if (op_lazy_builder.QueueEmpty()) {
    return;
  }
  try {
    MS_LOG(DEBUG) << "Start";
    auto ms_context = MsContext::GetInstance();
    auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
    CompileSingleOpGraphs(op_lazy_builder.GetOpBuildTasks());
    op_lazy_builder.ClearOpBuildTasks();
    // Run the ops one by one.
    auto &op_run_tasks = op_lazy_builder.GetOpRunTasks();
    while (!op_run_tasks.empty()) {
      auto &op_run_task = op_run_tasks.front();
      const auto &context = op_run_task->context();
      RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(),
                       context->graph_compiler_info(), context->device_context());
      UpdateOutputDeviceAddress(context->output_nodes(), context->device_context());
      UpdateInputDeviceAddress(context->graph());
      op_lazy_builder.PopOpRunTask();
    }
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, infer_flag);
    MS_LOG(DEBUG) << "End";
  } catch (const py::type_error &ex) {
    op_lazy_builder.Reset();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    op_lazy_builder.Reset();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    op_lazy_builder.Reset();
    throw py::index_error(ex);
  } catch (const py::name_error &ex) {
    op_lazy_builder.Reset();
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    op_lazy_builder.Reset();
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    op_lazy_builder.Reset();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when executing tasks in the queue. Exception name: " << exName;
  }
}

void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *graph_compiler_info,
                                  OpRunInfo *op_run_info, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  // Fetch the output nodes.
  const auto &graph = graph_compiler_info->graphs_.front();
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
  MS_EXCEPTION_IF_NULL(outputs);
  auto device_context = graph_compiler_info->device_contexts_.front();
  auto &op_lazy_builder = runtime::OpLazyBuilder::GetInstance();
  // Disable lazy build when:
  // 1. Executing a dynamic shape operator: the output shape depends on the operator's computed result.
  // 2. The cache is hit and there are no tasks in the queue, e.g. a non-first iteration.
  // 3. Not in an nn.Cell construct.
  bool lazy_build_disabled = graph_compiler_info->need_erase_ ||
                             (single_op_cache_hit && op_lazy_builder.QueueEmpty()) || !op_run_info->lazy_build;
  if (lazy_build_disabled) {
    if (!op_lazy_builder.QueueEmpty()) {
      op_lazy_builder.ExecuteRemainingTasks();
    }
    if (!single_op_cache_hit) {
      CompileSingleOpGraph(graph, device_context, graph_compiler_info);
    }
    RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context);
    UpdateOutput(output_nodes, outputs);
    UpdateOutputDeviceAddress(output_nodes, device_context);
    UpdateInputDeviceAddress(graph);
    if (op_run_info->is_dynamic_shape) {
      UpdateOutputAbstract(graph, op_run_info);
    }
    if (graph_compiler_info->need_erase_) {
      EraseSingleOpCache(graph_compiler_info->name_, graph);
    }
  } else {
    UpdateOutput(output_nodes, outputs);
    auto run_op_context = std::make_shared<runtime::OpLazyBuilderContext>(
      graph_compiler_info, graph, output_nodes, *op_run_info, graph_compiler_info->device_contexts_.front());
    if (!single_op_cache_hit) {
      op_lazy_builder.PushOpBuildTask(std::make_shared<runtime::OpBuildTask>(run_op_context));
    }
    op_lazy_builder.PushOpRunTask(std::make_shared<runtime::OpRunTask>(run_op_context));
    if (!op_lazy_builder.registered()) {
      op_lazy_builder.Register([this]() { LazyExecuteTaskCallback(); });
    }
    if (op_lazy_builder.QueueFull()) {
      op_lazy_builder.ExecuteRemainingTasks();
    }
  }
}
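
// Descriptive note: RunOpInternal implements the lazy-build fast path. When lazy build stays enabled, the
// op's build and run work is queued as OpBuildTask/OpRunTask and only flushed (via LazyExecuteTaskCallback)
// when the queue fills up or a synchronization point is reached; otherwise the op is compiled and executed
// immediately and its outputs, device addresses, and abstract are updated in place.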

void MindRTBackend::RunOp(OpRunInfo *op_run_info, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  // Get the device context.
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();
  bool single_op_cache_hit = true;
  auto graph_id = graph_compiler_->CompileGraph(*op_run_info, &single_op_cache_hit, device_context);
  std::string actor_info = std::to_string(graph_id) + "_" + op_run_info->op_name;
  GraphCompilerInfo *graph_compiler_info_ptr;
  if (single_op_cache_hit) {
    auto iter = actor_to_graph_compiler_info_.find(actor_info);
    if (iter == actor_to_graph_compiler_info_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find graph compiler info for actor set: " << actor_info;
    }
    graph_compiler_info_ptr = iter->second.get();
  } else {
    graph_info_to_device_context_.clear();
    graph_info_to_device_context_[op_run_info->graph_info] = device_context;
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
    auto graph_compiler_info =
      ConstructGraphCompilerInfo(actor_info, &op_run_info->tensor_mask, &op_run_info->input_tensors, !enable_cache);
    graph_compiler_info_ptr = graph_compiler_info.get();
    auto ret = actor_to_graph_compiler_info_.try_emplace(actor_info, std::move(graph_compiler_info));
    if (!ret.second) {
      MS_LOG(WARNING) << "ActorInfo:" << actor_info << " already exists in the map.";
    }
  }
  RunOpInternal(single_op_cache_hit, graph_compiler_info_ptr, op_run_info, outputs);
}

void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                         GraphCompilerInfo *graph_compiler_info) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  graph_compiler_->BuildSingleOpGraphs({graph}, device_context);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
  graph_compiler_info->input_tensors_.clear();
  // Actor::Init() is called in Schedule.
  // The workspace needs to be initialized in Actor::Init(),
  // so Schedule must execute after CreateKernelWorkspaceDeviceAddress.
  runtime::GraphScheduler::GetInstance().Schedule(actor_set);
}
}  // namespace compile
}  // namespace mindspore