ascend_session.cc
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "session/ascend_session.h"
#include <algorithm>
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
#include "common/trans.h"
#include "device/kernel_runtime.h"
#include "device/ascend/kernel_select_ascend.h"
#include "device/ascend/kernel_build_ascend.h"
#include "device/ascend/ascend_kernel_runtime.h"
#include "device/ascend/ascend_device_address.h"
#include "pre_activate/ascend/ascend_backend_optimization.h"
#include "device/kernel_adjust.h"
#include "device/ascend/ascend_stream_assign.h"
#include "predict/predict.h"
#include "session/anf_runtime_algorithm.h"
#include "ir/scalar.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"
#include "common/utils.h"
#include "pre_activate/common/helper.h"
#include "device/kernel_runtime_manager.h"
#include "kernel/tbe/tbe_python_funcs.h"
#include "utils/config_manager.h"
namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
namespace {
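// Logs the merged execution order: for each kernel, its index, stream
// distinction label, owning graph id, and a short node description.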
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order) {
  MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
  MS_LOG(INFO) << "[index][stream_label][graph_id][node string]";
  int i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "[ " << i << "]"
                 << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]"
                 << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]"
                 << "[" << cnode->DebugString() << "]";
    i++;
  }
}
void DumpGraphInputArgs(const VectorRef &args) {
  // the original messages mixed printf-style placeholders ("%lu", "%s") into
  // stream logging; stream the values instead
  MS_LOG(INFO) << "Args size[" << args.size() << "]";
  for (size_t i = 0; i < args.size(); i++) {
    if (utils::isa<AnfNodePtr>(args[i])) {
      auto anf = utils::cast<AnfNodePtr>(args[i]);
      MS_EXCEPTION_IF_NULL(anf);
      MS_LOG(INFO) << "Parameter arg" << i << " = [" << anf->DebugString() << "]";
    } else if (utils::isa<ValuePtr>(args[i])) {
      auto value = utils::cast<ValuePtr>(args[i]);
      MS_EXCEPTION_IF_NULL(value);
      MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString();
    } else {
      MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString();
    }
  }
}
void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &node : graph->execution_order()) {
    // check the node before it is dereferenced in the condition below
    MS_EXCEPTION_IF_NULL(node);
    if (is_override || AnfAlgo::GetStreamDistinctionLabel(node.get()) == kInvalidDistincLabel) {
      AnfAlgo::SetStreamDistinctionLabel(label, node.get());
    }
  }
}
GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // if the graph is empty, use the graph id as its distinction label
  if (graph->execution_order().empty()) {
    return graph->graph_id();
  }
  // otherwise use the label of the first node in the execution order
  return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
}
std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  auto valid_inputs = graph->ValidInputs();
  size_t real_args_size = 0;
  std::vector<BaseRef> real_args = {};
  for (size_t i = 0; i < args.size(); i++) {
    if (utils::isa<AnfNodePtr>(args[i])) {
      auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem});
      for (auto &real_arg : tmp_args) {
        auto anf_node = utils::cast<AnfNodePtr>(real_arg);
        MS_EXCEPTION_IF_NULL(anf_node);
        auto abstract = anf_node->abstract();
        MS_EXCEPTION_IF_NULL(abstract);
        // a tuple-output real kernel expands into multiple parameters
        if (abstract->isa<abstract::AbstractTuple>() &&
            !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
          auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
          real_args_size += tuple_abstract->size();
          continue;
        }
        real_args_size += 1;
        real_args.push_back(real_arg);
      }
    } else {
      real_args_size += 1;
      real_args.push_back(args[i]);
    }
  }
  if (graph_inputs.size() != valid_inputs.size()) {
    MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
                      << ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
  }
  if (real_args_size != graph_inputs.size()) {
    for (size_t j = 0; j < valid_inputs.size(); j++) {
      if (valid_inputs[j]) {
        MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
      }
    }
    MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
                    << " not equal";
  }
  return real_args;
}
}  // namespace
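// Compiles a front-end sub-graph into a kernel graph: constructs the graph,
// runs IR fusion, selects kernels, converts the graph for the predict model,
// applies hardware optimization, initializes runtime resources, and assigns
// static memory for the graph inputs. Returns the id of the new graph.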
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  MS_LOG(INFO) << "Start";
  auto graph_id = graph_sum_;
  // construct the kernel graph; on success graph_sum_ is incremented
  auto graph = ConstructKernelGraph(lst, outputs);
  MS_EXCEPTION_IF_NULL(graph);
  opt::AscendBackendIRFusionOptimization(graph);
  // select kernel build info
  SelectKernel(*graph);
  // convert the kernel graph to a model
  predictmodel::StepConvertGraph(graph);
  // optimize graph
  HardwareOptimize(graph);
  // init runtime resource
  InitRuntimeResource();
  // assign static memory of parameters
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignStaticMemoryInput(graph.get());
  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  return graph_id;
}
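// Builds a compiled graph: merges child graph execution orders when building
// the final graph, adjusts kernels, assigns streams, and builds the kernels.
// Unless precompile_only is set, it also allocates memory, generates task
// info, and loads the tasks onto the device.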
void AscendSession::BuildGraph(GraphId graph_id) {
  MS_LOG(INFO) << "Start";
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  // handle the multi-graph case
  if (graph_id == final_graph_id_) {
    if (!graph->executable()) {
      return;
    }
    // merge child graphs
    MergeGraphExecOrder();
  } else {
    // set the distinction label of a single graph
    SetStreamDistinctionLabel(GetGraph(graph_id), graph_id, false);
  }
  // adjust the execution order after merging child graphs and other special operations
  AdjustKernel(graph);
  // assign streams for control sink, hccl and so on
  AssignStream(graph);
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
  // build kernels for cnodes
  BuildKernel(graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->precompile_only()) {
    MS_LOG(INFO) << "Precompile only, stop at the build kernel step";
  } else {
    // alloc memory, including static memory and dynamic memory
    MemoryAlloc(graph.get());
    // generate task info for task sink mode
    GenerateTaskInfo(graph);
    // load task info to the device in sink mode
    LoadTask(graph);
  }
  MS_LOG(INFO) << "End";
}
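// Runs a built graph: loads user inputs onto the device, executes the loaded
// tasks with the Python GIL released, then copies results, summary data, and
// optional debug dumps back from the device.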
void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                             VectorRef *const outputs) {
  MS_LOG(INFO) << "Start";
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // if no child graph is executable and no anf output exists
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has anf output";
    UpdateOutputs(kernel_graph, outputs, inputs);
    return;
  }
  // load input data from user input
  LoadInputData(kernel_graph, inputs);
  // convert inputs to model
  predictmodel::StepConvertWeight(inputs);
  {
    py::gil_scoped_release release;
    // run task on device
    ExecTask(kernel_graph);
  }
  // get result from device
  UpdateOutputs(kernel_graph, outputs, inputs);
  // summary
  Summary(kernel_graph.get());
  // dump used for debug
  Dump(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start";
  // data layout optimization
  opt::RunOpAscendDataLayout(kernel_graph);
  // mixed precision optimization
  opt::AscendMixPrecision(kernel_graph);
  MS_LOG(INFO) << "Finish";
}
void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}
void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                            const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::vector<bool> &tensors_mask) {
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
  if (GraphCacheExist(graph_info)) {
    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
    return;
  }
  // construct a graph that contains the single op
  auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
  MS_EXCEPTION_IF_NULL(graph);
  opt::RunOpAscendBackendIRFusionOptimization(graph);
  // kernel select
  SelectKernel(*graph);
  // optimize
  RunOpHardwareOptimize(graph);
  // init runtime resource
  InitRuntimeResource();
  // build kernel
  RunOpAdjustKernel(graph);
  BuildKernel(graph);
  run_op_graphs_[graph_info] = graph;
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
}
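// Executes a single cached op graph: allocates memory, loads the input
// tensors, launches the kernel, and converts the outputs into a Python tuple.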
py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                               const std::vector<tensor::TensorPtr> &input_tensors) {
  auto graph = run_op_graphs_[graph_info];
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
  // malloc mem
  RunOpMemoryAlloc(input_tensors, graph.get());
  // load input data to device
  LoadInputData(graph, input_tensors);
  // run op
  RunOpExecTask(graph);
  // get output
  VectorRef outputs;
  UpdateOutputs(graph, &outputs, input_tensors);
  // trans output to tuple
  auto output_tensors = TransformBaseRefListToTuple(outputs);
  if (!utils::isa<PyObjectRef>(output_tensors) ||
      !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
    MS_LOG(EXCEPTION) << "The output tensors should be a tuple !";
  }
  py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
  py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
  return tuple_tensors;
}
// compile graph steps
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  for (const auto &cnode : kernel_graph.execution_order()) {
    device::ascend::SelectKernelInfo(cnode);
    MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::InitRuntimeResource() {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  if (!runtime_instance->Init()) {
    MS_LOG(EXCEPTION) << "Kernel runtime init error.";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "HardwareOptimize start!";
  // check the graph before it is used
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::AscendBackendOptimization(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  MS_LOG(INFO) << "HardwareOptimize Finish!";
}
void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  device::KernelAdjust::GetInstance().Reorder(kernel_graph);
  opt::HideNopNode(kernel_graph.get());
  // insert ClearZero ops: kernels are built first so that the atomic info can
  // be read from the generated kernel json in the next step
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir";
    DumpIR(file_path, kernel_graph);
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  opt::HideNopNode(kernel_graph.get());
  // insert ClearZero ops: kernels are built first so that the atomic info can
  // be read from the generated kernel json in the next step
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  device::ascend::AscendStreamAssign::GetInstance().AssignStreamNew(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
  auto ret = device::ascend::KernelBuild(kernel_graph.get());
  if (!ret) {
    MS_LOG(EXCEPTION) << "Kernel build error.";
  }
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  // stream the cost directly; the printf-style PRIu64 macro does not work with stream logging
  MS_LOG(INFO) << "KernelBuild run in " << cost << " us";
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignMemory(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
                                     KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start memory alloc!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(context_, kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Generate task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->LoadTask(kernel_graph.get());
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Load task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Run(kernel_graph.get());
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  (void)runtime_instance->DumpData(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}
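// Creates the final (merged) graph and binds the front-end arguments to it:
// parameters that already belong to a child graph are reused, all others are
// created as new backend parameters of the final graph.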
GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
  MS_LOG(INFO) << "Start! Args size " << args.size();
  auto final_graph = std::make_shared<KernelGraph>();
  final_graph_id_ = graph_sum_++;
  graphs_[final_graph_id_] = final_graph;
  final_graph->set_graph_id(final_graph_id_);
  MS_LOG(INFO) << "Create a new final graph " << final_graph_id_ << " success";
  // init private variables and bind them with final_graph_id
  graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
  graph_order_types_[final_graph_id_] = std::vector<GraphType>();
  for (const auto &parameter : args) {
    MS_EXCEPTION_IF_NULL(parameter);
    if (!parameter->isa<Parameter>()) {
      MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!";
    }
    AnfNodePtr parameter_backend = nullptr;
    // kInvalidGraphId means the parameter does not exist in any child graph
    auto parameter_belong_graph_id = GetGraphIdByNode(parameter);
    if (parameter_belong_graph_id == kInvalidGraphId) {
      parameter_backend = final_graph->NewParameter(parameter->cast<ParameterPtr>());
      final_graph->FrontBackendlMapAdd(parameter, parameter_backend);
      MS_LOG(INFO) << "New parameter " << parameter->DebugString() << " in final_graph";
    } else {
      // the parameter is a parameter of a child graph
      auto graph = GetGraph(parameter_belong_graph_id);
      MS_EXCEPTION_IF_NULL(graph);
      MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph ["
                   << parameter_belong_graph_id << "]";
      parameter_backend = graph->GetBackendAnfByFrontAnf(parameter);
    }
    MS_EXCEPTION_IF_NULL(parameter_backend);
    MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
                 << AnfAlgo::GetGraphId(parameter_backend.get());
    // add the backend parameter to the final graph inputs
    auto final_graph_inputs = final_graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(final_graph_inputs);
    final_graph_inputs->push_back(parameter_backend);
  }
  MS_LOG(INFO) << "End final_graph_id " << final_graph_id_;
  return final_graph_id_;
}
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
  auto final_graph = GetGraph(final_graph_id_);
  MS_EXCEPTION_IF_NULL(final_graph);
  if (!utils::isa<AnfNodePtr>(output)) {
    if (!utils::isa<ValuePtr>(output)) {
      MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
    }
    auto value_ptr = utils::cast<ValuePtr>(output);
    auto value_node = NewValueNode(value_ptr);
    MS_EXCEPTION_IF_NULL(value_node);
    auto kernel_info = std::make_shared<device::KernelInfo>();
    value_node->set_kernel_info(kernel_info);
    value_node->set_abstract(abstract::FromValue(value_ptr));
    final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
    final_graph->set_executable(false);
    MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
    return;
  }
  // get the backend anf node related to the front output node
  auto output_anf_node = utils::cast<AnfNodePtr>(output);
  auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
  auto output_from_graph = GetGraph(output_from_graph_id);
  MS_EXCEPTION_IF_NULL(output_anf_node);
  MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
               << "] to final graph";
  MS_EXCEPTION_IF_NULL(output_from_graph);
  // if the output comes from the final graph itself, no child graph exists
  if (final_graph_id_ == output_from_graph_id) {
    MS_LOG(INFO) << "No child graph, output is " << output_anf_node->DebugString();
    final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
    final_graph->set_executable(false);
    return;
  }
  final_graph->set_output(output_from_graph->output());
}
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
  auto it = graphs_.find(graph_id);
  if (it == graphs_.end()) {
    MS_LOG(WARNING) << "Can't find graph " << graph_id;
    return nullptr;
  }
  return it->second;
}
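// Appends a StreamSwitch kernel to the condition graph. The switch compares
// the condition graph's output against a constant zero and, when the
// condition holds, activates the stream labeled with true_graph_id.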
void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) {
  MS_LOG(INFO) << "Start!";
  MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "], true graph id[" << true_graph_id << "]";
  auto condition_graph = GetGraph(condition_graph_id);
  MS_EXCEPTION_IF_NULL(condition_graph);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{1});
  int32_t *val = nullptr;
  val = static_cast<int32_t *>(tensor->data_c(true));
  MS_EXCEPTION_IF_NULL(val);
  *val = 0;
  auto value_node = std::make_shared<ValueNode>(tensor);
  value_node->set_abstract(abstract::FromValue(tensor, false));
  auto counter_const = condition_graph->NewValueNode(value_node);
  condition_graph->AddValueNodeToGraph(counter_const);
  // create a new switch op
  auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeInt32});
  kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
  kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
  auto cond_output_it = condition_output_.find(condition_graph_id);
  if (cond_output_it == condition_output_.end()) {
    MS_LOG(EXCEPTION) << "Can't find condition graph " << condition_graph_id;
  }
  auto cond_output_kernel =
    AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first;
  MS_EXCEPTION_IF_NULL(cond_output_kernel);
  std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
  CNodePtr switch_node = condition_graph->NewCNode(inputs);
  // check the node before setting its kernel build info
  MS_EXCEPTION_IF_NULL(switch_node);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
  switch_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  AnfAlgo::SetGraphId(condition_graph_id, switch_node.get());
  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(GetGraph(condition_graph_id)), switch_node.get());
  // set attr: cond RT_GREATER
  AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue<int>(static_cast<int>(RT_GREATER)), switch_node);
  // set attr: data_type
  AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue<int>(static_cast<int>(RT_SWITCH_INT64)), switch_node);
  // set attr: true branch graph id, which is the same as the stream distinction label
  AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(true_graph_id), switch_node);
  // append the switch at the end of the condition graph
  std::vector<CNodePtr> exec_order = condition_graph->execution_order();
  exec_order.push_back(switch_node);
  condition_graph->set_execution_order(exec_order);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
  auto &graph_execute_order = GetGraphOrder(final_graph_id_);
  auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id);
  if (false_index == kInvalidIndex || false_index == 0) {
    return;
  }
  for (int i = SizeToInt(false_index) - 1; i >= 0; i--) {
    size_t graph_index = IntToSize(i);
    if (graph_index >= graph_execute_order.size()) {
      MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]";
    }
    if (graph_order_type[graph_index] == COMMON_GRAPH) {
      auto true_last_id = graph_execute_order[graph_index];
      MS_LOG(INFO) << "The last graph of the true branch is " << true_last_id;
      auto true_last = GetGraph(true_last_id);
      auto final_graph = GetGraph(final_graph_id_);
      MS_EXCEPTION_IF_NULL(final_graph);
      auto false_last_id = AnfAlgo::GetGraphId(final_graph->output().get());
      auto false_last = GetGraph(false_last_id);
      MS_EXCEPTION_IF_NULL(true_last);
      MS_EXCEPTION_IF_NULL(false_last);
      MS_LOG(INFO) << "The last graph of the false branch is " << false_last_id;
      // for now only the single-output case is considered
      InsertMultipleAssignToGraph(true_last_id, true_last->output(), false_last->output());
      // insert stream active for loop sink
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
          ConfigManager::GetInstance().iter_num() > 1) {
        // insert active in the true graph; another active is inserted during kernel adjust
        InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel);
      }
      break;
    }
  }
}
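// Records the true/false branch graphs for a condition graph, marks the
// condition graph in the graph order, and propagates the condition graph's
// stream distinction label down the chain of false branches.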
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id,
                                  const AnfNodePtr &output) {
  if (switches_.find(cond_graph_id) != switches_.end()) {
    MS_LOG(WARNING) << "Condition graph " << cond_graph_id << " has been set before";
    return;
  }
  switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id);
  condition_output_[cond_graph_id] = output;
  MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
  // set the type of the condition graph
  auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
  auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  if (cond_graph_index >= graph_order_type.size()) {
    MS_LOG(EXCEPTION) << "cond_graph_index " << cond_graph_index << " out of range " << graph_order_type.size();
  }
  graph_order_type[cond_graph_index] = CONDITION_GRAPH;
  // update the distinction label of the false graph; update before merge to keep the labels consistent
  if (false_graph_id != kInvalidGraphId) {
    // the false graph runs on the same stream as the condition graph
    auto condition_graph = GetGraph(cond_graph_id);
    SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
    // if the false graph is itself a condition graph that was switch-compiled before,
    // its false branch must be updated again
    auto cond_it = switches_.find(false_graph_id);
    while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) {
      cond_graph_id = cond_it->first;
      false_graph_id = cond_it->second.second;
      condition_graph = GetGraph(cond_graph_id);
      SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true);
      cond_it = switches_.find(false_graph_id);
    }
  }
}
void AscendSession::MergeSwitchCompile() {
  auto graph_execute_order = GetGraphOrder(final_graph_id_);
  auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  for (auto switch_compile : switches_) {
    auto cond_graph_id = switch_compile.first;
    auto true_graph_id = switch_compile.second.first;
    auto false_graph_id = switch_compile.second.second;
    MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
    auto condition_graph = GetGraph(cond_graph_id);
    auto final_graph = GetGraph(final_graph_id_);
    MS_EXCEPTION_IF_NULL(condition_graph);
    MS_EXCEPTION_IF_NULL(final_graph);
    // insert switch into the condition graph
    InsertSwitchToGraph(cond_graph_id, true_graph_id);
    auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
    auto prev_graph_id = kInvalidGraphId;
    // if the condition graph is the first graph and the final graph has assign ops,
    // then the final graph is the common graph
    if (cond_graph_index == 0 && !final_graph->execution_order().empty()) {
      prev_graph_id = final_graph_id_;
      // set the distinction label of the final graph
      SetStreamDistinctionLabel(final_graph, final_graph_id_, true);
      // if the condition graph is not the first graph
    } else if ((cond_graph_index - 1 < graph_execute_order.size()) &&
               (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) {
      prev_graph_id = graph_execute_order[cond_graph_index - 1];
    }
    // insert stream active into the common graph
    if (prev_graph_id != kInvalidGraphId) {
      InsertStreamActiveToGraph(prev_graph_id, GetDistinctionLabel(condition_graph));
    }
    // check whether this is an 'if' condition or a 'while' condition
    auto it = while_condition_graphs_.find(cond_graph_id);
    if (it == while_condition_graphs_.end()) {
      CopyOutputOfIf(false_graph_id);
    } else {
      // for a while loop, insert a stream active that re-activates the condition graph
      GraphId from_graph = it->second;
      InsertStreamActiveToGraph(from_graph, GetDistinctionLabel(condition_graph));
    }
  }
  MS_LOG(INFO) << "Finish!";
}
// insert active to graph
void AscendSession::SetActive(GraphId from, GraphId to) {
  if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
    MS_LOG(WARNING) << "To " << to << " already exists in the map, from " << from << ", existing from "
                    << while_condition_graphs_[to];
    return;
  }
  MS_LOG(INFO) << "From " << from << " to " << to;
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  std::vector<GraphId> graph_order_new;
  std::vector<GraphType> graph_type_new;
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto graph_id = graph_order[i];
    graph_order_new.push_back(graph_id);
    graph_type_new.push_back(graph_type[i]);
    if (from == graph_id) {
      graph_order_new.push_back(kInvalidGraphId);
      graph_type_new.push_back(BRANCH_END);
    }
  }
  graph_order = graph_order_new;
  graph_type = graph_type_new;
  // set the graph type of the condition graph
  graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH;
  // record the condition graph in the while-condition map
  while_condition_graphs_[to] = from;
}
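// Binds a front-end argument to a backend parameter of a child graph. When
// the argument is itself a parameter its device address is shared directly;
// otherwise an Assign kernel is appended to the source graph.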
void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(backend_parameter);
  MS_EXCEPTION_IF_NULL(front_anf);
  if (!backend_parameter->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Backend parameter's type is not a parameter, but is " << backend_parameter->ToString();
  }
  auto from_graph_id = GetGraphIdByNode(front_anf);
  auto from_graph = GetGraph(from_graph_id);
  MS_EXCEPTION_IF_NULL(from_graph);
  auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
  auto to_graph = GetGraph(to_graph_id);
  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
  MS_EXCEPTION_IF_NULL(to_graph);
  MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "] to node["
               << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
               << "]";
  // a node should not be assigned to itself
  if (backend_arg.get() == backend_parameter.get()) {
    return;
  }
  // if the arg is a parameter of a child graph, it is a parameter of the final graph too
  if (front_anf->isa<Parameter>()) {
    MS_EXCEPTION_IF_NULL(backend_arg);
    if (!AnfAlgo::OutputAddrExist(backend_arg, 0)) {
      // let the parameter in the child graph share the address of the parameter in the final graph
      AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_parameter, 0), 0, backend_arg.get());
      MS_LOG(INFO) << "Assign mem of node " << backend_parameter->DebugString() << " of graph "
                   << AnfAlgo::GetGraphId(backend_parameter.get()) << " to node " << backend_arg->DebugString()
                   << " of graph " << AnfAlgo::GetGraphId(backend_arg.get());
      return;
    }
    // if a parameter is a weight that is not linked to any executable node, its device type
    // will be kTypeUnknown; set its device type to the same as the arg's
    if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
      AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
    }
    // if the front anf is a parameter, the value can be assigned back, because backend_parameter
    // won't be changed in its graph unless it is a weight; if it is a weight, the value should
    // be assigned back as well
    AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_arg, 0), 0, backend_parameter.get());
    return;
  }
  InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter) {
  MS_LOG(INFO) << "Start!";
  // sync data from host to device
  MS_EXCEPTION_IF_NULL(front_tensor);
  size_t tensor_size = front_tensor->data().nbytes();
  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  MS_EXCEPTION_IF_NULL(addr);
  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
                              front_tensor->data_type(), front_tensor->data_c(false))) {
    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  }
  MS_LOG(INFO) << "Finish!";
}
void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
  MS_LOG(INFO) << "to_graph_id " << to_graph_id;
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  for (size_t i = 0; i < graph_order.size(); i++) {
    if (graph_order[i] == to_graph_id) {
      return;
    }
  }
  // if the graph is not in the graph order, add it
  SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false);
  graph_order.push_back(to_graph_id);
  graph_type.push_back(COMMON_GRAPH);
  for (size_t i = 0; i < graph_order.size(); i++) {
    MS_LOG(INFO) << "Index " << i << ", graph_id " << graph_order[i] << ", graph_type " << graph_type[i];
  }
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) {
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    return input_index + output_num;
  }
  auto &graph_inputs = graph->inputs();
  auto &valid_inputs = graph->ValidInputs();
  if (valid_inputs[input_index]) {
    SetChildGraphParameter(node, graph_inputs[input_index]);
  } else {
    MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
  }
  return ++input_index;
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<Tensor>()) {
    MS_LOG(EXCEPTION) << "Value node should be a tensor, unexpected value: " << value->ToString();
  }
  auto &graph_inputs = graph->inputs();
  SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
  return ++input_index;
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) {
  auto index = input_index;
  for (auto &arg : vec_args) {
    if (utils::isa<AnfNodePtr>(arg)) {
      // arg is an anf node
      auto node = utils::cast<AnfNodePtr>(arg);
      index = SetChildGraphInput(graph, node, input_index);
    } else if (utils::isa<ValuePtr>(arg)) {
      // arg is a tensor
      auto value = utils::cast<ValuePtr>(arg);
      index = SetChildGraphInput(graph, value, input_index);
    } else {
      MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString();
    }
  }
  return index;
}
void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
  MS_LOG(INFO) << "Set input of graph " << g;
  auto to_graph = GetGraph(g);
  MS_EXCEPTION_IF_NULL(to_graph);
  DumpGraphInputArgs(args);
  UpdateGraphOrder(g);
  auto &graph_inputs = to_graph->inputs();
  auto real_args = GetRealArgs(to_graph, args);
  size_t input_index = 0;
  for (size_t i = 0; i < real_args.size(); i++) {
    if (input_index >= graph_inputs.size()) {
      MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
    }
    auto &real_arg = real_args[i];
    if (utils::isa<AnfNodePtr>(real_arg)) {
      // arg is an anf node
      auto node = utils::cast<AnfNodePtr>(real_arg);
      input_index = SetChildGraphInput(to_graph, node, input_index);
    } else if (utils::isa<ValuePtr>(real_arg)) {
      // arg is a tensor
      auto value = utils::cast<ValuePtr>(real_arg);
      input_index = SetChildGraphInput(to_graph, value, input_index);
    } else if (utils::isa<VectorRef>(real_arg)) {
      // arg is a VectorRef
      auto vec_args = utils::cast<VectorRef>(real_arg);
      input_index = SetChildGraphInput(to_graph, vec_args, input_index);
    } else {
      MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString();
    }
  }
  MS_LOG(INFO) << "Finish!";
}
GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  for (const auto &graph_item : graphs_) {
    auto graph = graph_item.second;
    MS_EXCEPTION_IF_NULL(graph);
    // if front_anf is a parameter, a matching backend parameter may exist in more than one graph
    if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
      return graph_item.first;
    }
  }
  MS_EXCEPTION_IF_NULL(front_anf);
  MS_LOG(DEBUG) << "front_anf " << front_anf->DebugString() << " does not exist in any graph";
  return kInvalidGraphId;
}
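// Concatenates the execution orders of all child graphs (skipping branch
// markers) into the final graph, and migrates their value nodes and ref maps.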
void AscendSession::MergeGraphExecOrder() {
  MS_LOG(INFO) << "Start!";
  // insert switches into the graphs
  MergeSwitchCompile();
  // merge graph order
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  auto final_graph = GetGraph(final_graph_id_);
  MS_EXCEPTION_IF_NULL(final_graph);
  if (graph_order.empty()) {
    MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
    return;
  }
  // if the first graph is a common graph, the final graph has no label yet,
  // so give the final graph the same stream label as the first graph
  SetStreamDistinctionLabel(final_graph, graph_order[0], false);
  std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
  KernelGraphPtr last_graph = nullptr;
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto graph_id = graph_order[i];
    if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
      continue;
    }
    auto child_graph = GetGraph(graph_id);
    last_graph = child_graph;
    MS_EXCEPTION_IF_NULL(child_graph);
    auto exec_order = child_graph->execution_order();
    MS_LOG(INFO) << "Merge graph, graph_id " << graph_id;
    (void)std::copy(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order));
    // add all value nodes of the child graphs to the final graph
    for (auto &value_node : child_graph->graph_value_nodes()) {
      final_graph->AddValueNodeToGraph(value_node);
    }
    // copy the ref map to the final graph
    auto child_ref_map = child_graph->GetRefMap();
    for (auto &item : child_ref_map) {
      if (final_graph->IsInRefOutputMap(item.first)) {
        MS_LOG(EXCEPTION) << "The ref pair is already in the final graph!";
      }
      final_graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
  // set final_exec_order into the final graph
  DumpGraphExeOrder(final_exec_order);
  final_graph->set_execution_order(final_exec_order);
}
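// Appends an Assign kernel to the given graph that copies the value of
// 'from' into 'to'; skipped when both nodes already share a device address.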
void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
      AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
    return;
  }
  if (from.get() == to.get()) {
    return;
  }
  MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to "
               << to->DebugString();
  auto graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(graph);
  // config the inputs of the assign node
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  // generate a new cnode
  auto assign_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(assign_node);
  assign_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), assign_node.get());
  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(graph), assign_node.get());
  // append the assign at the end of the graph
  auto exec_order = graph->execution_order();
  exec_order.push_back(assign_node);
  graph->set_execution_order(exec_order);
}
void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
  std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
  MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to ["
               << AnfAlgo::GetGraphId(to.get()) << "]";
  if (from_outputs.size() != to_outputs.size()) {
    MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]";
    MS_LOG(EXCEPTION) << "'from' outputs size[" << from_outputs.size() << "] is not equal to 'to' outputs size["
                      << to_outputs.size() << "]";
  }
  for (size_t i = 0; i < from_outputs.size(); i++) {
    InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]);
  }
}
void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) {
  MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
  auto from_graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(from_graph);
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
  auto active_node = from_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(active_node);
  active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), active_node.get());
  // set the id of the stream to activate into the attr of the active node
  std::vector<uint32_t> active_index_value = {};
  active_index_value.push_back(actived_stream);
  AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(from_graph), active_node.get());
  // append the active node at the end of the graph
  auto exec_order = from_graph->execution_order();
  exec_order.push_back(active_node);
  from_graph->set_execution_order(exec_order);
}
size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
  auto &graph_order = GetGraphOrder(final_graph);
  for (size_t i = 0; i < graph_order.size(); i++) {
    if (child_graph == graph_order[i]) {
      return i;
    }
  }
  return kInvalidIndex;
}
std::vector<GraphId> &AscendSession::GetGraphOrder(GraphId final_graph_id) {
  auto graph_order_iter = graph_execute_orders_.find(final_graph_id);
  if (graph_order_iter == graph_execute_orders_.end()) {
    MS_LOG(EXCEPTION) << "Final graph " << final_graph_id << " has no child graph";
  }
  return graph_order_iter->second;
}
// get the graph order type vector by final graph id
std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id) {
  auto graph_type_iter = graph_order_types_.find(final_graph_id);
  if (graph_type_iter == graph_order_types_.end()) {
    MS_LOG(EXCEPTION) << "Final graph " << final_graph_id << " has no graph_order_types_";
  }
  return graph_type_iter->second;
}
}  // namespace session
}  // namespace mindspore