
ascend_session.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/session/ascend_session.h"
#include <algorithm>
#include <map>
#include <tuple>
#include <set>
#include <string>
#include <list>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "ir/anf.h"
#include "common/trans.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/kernel_select_ascend.h"
#include "runtime/device/ascend/kernel_build_ascend.h"
#include "runtime/device/ascend/ascend_kernel_runtime.h"
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
#include "backend/optimizer/common/common_backend_optimization.h"
#include "runtime/device/kernel_adjust.h"
#include "runtime/device/ascend/ascend_stream_assign.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/config_manager.h"
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "debug/data_dump/e2e_dump_util.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"

namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
constexpr size_t kReturnDataIndex = 1;
namespace {
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
  MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
  MS_LOG(INFO) << "[index][stream_label][graph_id][node string]";
  int i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "[ " << i << "]"
                 << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]"
                 << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]"
                 << "[" << cnode->DebugString() << "]";
    i++;
  }
  std::stringstream buf;
  buf << "================== execution order ==================\n";
  if (!tag.empty()) {
    buf << tag << "\n";
  }
  buf << "execution_order size: " << execution_order.size() << "\n";
  i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    buf << i << ":\n";
    buf << "\t" << cnode->DebugString() << "\n";
    buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n";
    buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n";
    i++;
  }
  buf << "================== execution order ==================\n";
}

void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
  MS_EXCEPTION_IF_NULL(graph);
  if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
    graph->set_stream_distinction_label(label);
  }
}

std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
  std::vector<CNodePtr> cnodes = {};
  size_t i = 0;
  for (const auto &anf : anf_nodes) {
    MS_EXCEPTION_IF_NULL(anf);
    MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString();
    if (anf->isa<CNode>()) {
      cnodes.push_back(anf->cast<CNodePtr>());
    }
  }
  return cnodes;
}

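// Wrap the root graph's output in a MakeTuple so that downstream handling can
// treat graph outputs uniformly.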
void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
  auto return_node = root_graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (return_node->size() <= kReturnDataIndex) {
    return;
  }
  auto make_tuple = root_graph->NewCNode(
    {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
  root_graph->set_output(make_tuple);
}
}  // namespace

GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  MS_LOG(INFO) << "Start";
  // construct the graph; on success, graph_sum_ is increased by 1
  auto graph = ConstructKernelGraph(lst, outputs);
  auto graph_id = graph->graph_id();
  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  return graph_id;
}

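// Compile a whole func_graph: construct kernel graphs for the root and all child
// graphs, run backend optimization, link control flow, select kernels, optimize
// for hardware, assign memory and streams, build kernels, and load the tasks to
// the device.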
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  MS_LOG(INFO) << "Start";
  std::vector<KernelGraphPtr> all_graphs;
  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  BackendOptimization(all_graphs);
  // an empty graph does not enter the backend
  if (root_graph->execution_order().empty()) {
    MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
    InsertMakeTupleForOutput(NOT_NULL(root_graph));
    root_graph->set_executable(false);
    InitRuntimeResource();
    return root_graph->graph_id();
  }
  // create parameters for multiple branches
  std::set<KernelGraphPtr> memo;
  CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  // insert goto labels and label_sets
  LinkChildGraphs(NOT_NULL(root_graph));
  // resource initialize
  InitRuntimeResource();
  IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  SelectKernel(NOT_NULL(root_graph));
  memo.clear();
  HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  // add make_tuple to the output of the root graph
  InsertMakeTupleForOutput(NOT_NULL(root_graph));
  // validate the root graph, including generating the execution order and so on
  RootGraphExecutorValidate(NOT_NULL(root_graph));
  // adjust kernel
  AdjustKernel(root_graph);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  // Assign parameter keys.
  AssignParamKey(root_graph);
#endif
  // assign stream
  AssignStream(NOT_NULL(root_graph));
  // insert profiling point
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
  // build kernel
  BuildKernel(root_graph);
#ifdef ENABLE_DEBUGGER
  if (debugger_) {
    debugger_->PreExecute(root_graph);
  }
#endif
  // alloc mem
  MemoryAlloc(root_graph.get());
  // generate and load tasks to the device
  Load(root_graph);
  DumpAllGraphs(all_graphs);
  // return the root_graph id to backend
  auto graph_id = root_graph->graph_id();
  return graph_id;
}

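// Mark the final graph as containing summary nodes if any of its child graphs does.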
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto graph_order = GetGraphOrder(kernel_graph->graph_id());
  for (auto graph_id : graph_order) {
    auto child_graph = GetGraph(graph_id);
    if (child_graph == nullptr) {
      continue;
    }
    if (child_graph->summary_node_exist()) {
      kernel_graph->set_summary_node_exist(true);
      return;
    }
  }
  kernel_graph->set_summary_node_exist(false);
}

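// Build a graph for execution. The final graph (multi-graph control flow) compiles
// every child graph and merges their execution orders; a single graph is compiled
// directly and given its own stream distinction label.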
void AscendSession::BuildGraph(GraphId graph_id) {
  MS_LOG(INFO) << "Start";
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  // resource initialize
  InitRuntimeResource();
  // handle multiple graphs
  if (graph_id == final_graph_id_) {
    if (!graph->executable()) {
      return;
    }
    // insert assigns to child graphs
    InsertAllAssigns();
    SetFinalGraphSummaryFlag(graph);
    // OptChildGraphs
    auto graph_order = GetGraphOrder(final_graph_id_);
    auto &graph_type = GetGraphOrderType(final_graph_id_);
    for (size_t i = 0; i < graph_order.size(); i++) {
      if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) {
        auto child_graph = GetGraph(graph_order[i]);
        CompileChildGraph(child_graph);
      }
    }
    SetSummaryNodes(graph.get());
    // merge child graphs
    MergeGraphExecOrder();
  } else {
    auto single_graph = GetGraph(graph_id);
    MS_EXCEPTION_IF_NULL(single_graph);
    CompileChildGraph(single_graph);
    // set the distinction label of the single graph
    single_graph->set_stream_distinction_label(graph_id);
    single_graph->UpdateExecuteKernelStreamLabel();
  }
  // adjust the execution order after merging child graphs and other special operations
  AdjustKernel(graph);
  // assign streams for control sink, hccl and so on
  AssignStream(NOT_NULL(graph));
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
  // build kernels for cnodes
  BuildKernel(graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifdef ENABLE_DEBUGGER
  if (debugger_) {
    debugger_->PreExecute(graph);
  }
#endif
  if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "Precompile only, stop in build kernel step";
  } else {
    // alloc memory, including static memory and dynamic memory
    MemoryAlloc(graph.get());
    // generate and load task info to the device in sink mode
    Load(graph);
  }
  // sync the initial const tensors to device
  SyncInitialTenosrToDevice();
  DumpAllGraphs({graph});
  MS_LOG(INFO) << "End";
}
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
  MS_EXCEPTION_IF_NULL(child_graph);
  MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
  opt::AscendBackendIRFusionOptimization(child_graph);
  opt::AscendBackendFuseBasicOpt(child_graph, true);
  opt::AscendBackendGraphKernelOpt(child_graph, true);
  child_graph->SetExecOrderByDefault();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
    DumpIR(file_path, child_graph);
  }
  // select kernel build info
  SelectKernel(*child_graph);
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
    DumpIR(file_path, child_graph);
  }
  // optimize graph
  HardwareOptimize(child_graph);
  // assign static memory of parameters
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignStaticMemoryInput(child_graph.get());
  runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
}

void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                             VectorRef *const outputs) {
  MS_LOG(INFO) << "Start";
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // if no child graph exists and there is no anf output
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has anf output";
    return;
  }
  // load input data from user input
  LoadInputData(kernel_graph, inputs);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  // Initialize parameter server
  InitPSParamAndOptim(kernel_graph, inputs);
#endif
  {
    // run task on device
    Execute(kernel_graph, true);
  }
  // summary
  Summary(kernel_graph.get());
#ifdef ENABLE_DEBUGGER
  // load tensor from device for debugger
  if (debugger_ && debugger_->debugger_enabled()) {
    LoadTensor(kernel_graph);
  }
#endif
#ifdef ENABLE_DEBUGGER
  // debugger post-execution processing
  if (debugger_) {
    debugger_->PostExecute();
  }
#endif
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start";
  // data layout optimization
  opt::RunOpAscendDataLayout(kernel_graph);
  // mixed precision optimization
  opt::AscendMixPrecision(kernel_graph);
  MS_LOG(INFO) << "Finish";
}

bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}

void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                            const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::vector<int> &tensors_mask) {
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start!";
  if (GraphCacheExist(graph_info)) {
    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache already exists!";
    return;
  }
  // construct a graph that contains the single op
  auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
  MS_EXCEPTION_IF_NULL(graph);
  opt::RunOpAscendBackendIRFusionOptimization(graph);
  // kernel select
  SelectKernel(*graph);
  // optimize
  RunOpHardwareOptimize(graph);
  // init runtime resource
  InitRuntimeResource();
  // build kernel
  RunOpAdjustKernel(graph);
  BuildKernel(graph);
  run_op_graphs_[graph_info] = graph;
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish!";
}

void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                          const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  auto graph = run_op_graphs_[graph_info];
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
  // malloc mem
  RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
  // load input data to device
  LoadInputData(graph, input_tensors);
  // run op
  Execute(graph, false);
  // get output
  if (op_run_info.value != nullptr) {
    std::vector<tensor::TensorPtr> pre_output_tensors;
    TensorValueToTensor(op_run_info.value, &pre_output_tensors);
    for (auto &pre_output : pre_output_tensors) {
      tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
      tensor->set_device_address(pre_output->device_address());
      tensor->set_sync_status(kNoNeedSync);
      outputs->emplace_back(tensor);
    }
  } else {
    UpdateOutputs(graph, outputs, input_tensors);
  }
  RunOpMemoryClear(graph.get());
  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
}
// compile graph steps
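// Select a kernel build info for every node in the execution order; in graph mode,
// warn about how many nodes had to raise or reduce precision to find a kernel.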
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  size_t raise_precision_count = 0;
  size_t reduce_precision_count = 0;
  for (const auto &cnode : kernel_graph.execution_order()) {
    auto status = device::ascend::SelectKernelInfo(cnode);
    if (status == device::ascend::kStatusRaisePrecision) {
      raise_precision_count++;
    } else if (status == device::ascend::kStatusReducePrecision) {
      reduce_precision_count++;
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (raise_precision_count > 0) {
      MS_LOG(WARNING) << "There are " << raise_precision_count
                      << " node(s) that used raise precision to select the kernel!";
    }
    if (reduce_precision_count > 0) {
      MS_LOG(WARNING) << "There are " << reduce_precision_count
                      << " node(s) that used reduce precision to select the kernel!";
    }
  }
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::InitRuntimeResource() {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  if (!runtime_instance->Init()) {
    MS_LOG(EXCEPTION) << "Kernel runtime init error.";
  }
  DumpJsonParser::GetInstance().Parse();
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "HardwareOptimize start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::AscendBackendOptimization(kernel_graph);
  opt::AscendGraphKernelCommonProcess(kernel_graph);
  opt::AscendBackendFuseBasicOpt(kernel_graph, false);
  opt::AscendBackendAddAtomicClean(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  MS_LOG(INFO) << "HardwareOptimize Finish!";
}

void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  opt::HideNopNode(kernel_graph.get());
  // insert the ClearZero op
  // prepare for the next step: get atomic info from json
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir";
    DumpIR(file_path, kernel_graph);
  }
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  opt::HideNopNode(kernel_graph.get());
  // insert the ClearZero op
  // prepare for the next step: get atomic info from json
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
  auto ret = device::ascend::KernelBuild(kernel_graph.get());
  if (!ret) {
    MS_LOG(EXCEPTION) << "Kernel build error.";
  }
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(INFO) << "KernelBuild run in " << cost << " us";
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignMemory(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value,
                                     const std::vector<tensor::TensorPtr> &input_tensors,
                                     KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start memory alloc!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->RunOpClearMemory(kernel_graph);
}

void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Load task error!";
  }
  MS_LOG(INFO) << "Finish!";
}

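// Run the graph through the kernel runtime; task-sink mode is only consulted when
// is_task is true. Dump data after the run.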
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
  MS_LOG(INFO) << "Start!";
  bool is_task_sink = false;
  if (is_task) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  }
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
  Dump(kernel_graph);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  E2eDumpUtil::DumpData(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (!save_graphs) {
    return;
  }
  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_path, graph, true);
    DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
  }
#endif
}

void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
#ifdef ENABLE_DEBUGGER
  if (debugger_->DebuggerBackendEnabled()) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
    MS_EXCEPTION_IF_NULL(runtime_instance);
    DebugServices *debug_services = debugger_->debug_services();
    TensorLoader *tensor_loader = debug_services->tensor_loader();
    // TensorData will be freed up here
    tensor_loader->EmptyTensor();
    uint32_t iter_num = tensor_loader->GetIterNum();
    tensor_loader->set_iter_num(++iter_num);
    (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get());
    tensor_loader->EmptyPrevTensor();
  }
#endif
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph,
                                           std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(summary);
  // if the final graph has no child graph
  auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
  if (graph_order_iter == graph_execute_orders_.end()) {
    SessionBasic::SetSummaryNodes(graph);
    auto summary_nodes = graph->summary_nodes();
    summary->insert(summary_nodes.begin(), summary_nodes.end());
    return;
  }
  // for every child graph, find summary nodes
  auto graph_order = GetGraphOrder(graph->graph_id());
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto child_graph = GetGraph(graph_order[i]);
    if (child_graph == nullptr) {
      continue;
    }
    SessionBasic::SetSummaryNodes(child_graph.get());
    auto child_graph_summary = child_graph->summary_nodes();
    summary->insert(child_graph_summary.begin(), child_graph_summary.end());
    RecurseSetSummaryNodes(child_graph.get(), summary);
  }
  graph->set_summary_nodes(*summary);
}

void AscendSession::SetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  auto summary_nodes = graph->summary_nodes();
  std::map<std::string, std::pair<AnfNodePtr, int>> summary;
  summary.insert(summary_nodes.begin(), summary_nodes.end());
  RecurseSetSummaryNodes(graph, &summary);
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}

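// Materialize every recorded (front node, target graph, input index) entry in
// assigns_ as an Assign from the backend argument to the target graph's input
// parameter, skipping duplicates.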
void AscendSession::InsertAllAssigns() {
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
  for (auto assign : assigns_) {
    auto front_anf = std::get<0>(assign);
    auto to_graph_id = std::get<1>(assign);
    auto input_idx = std::get<2>(assign);
    auto to_graph = GetGraph(to_graph_id);
    MS_EXCEPTION_IF_NULL(to_graph);
    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
    if (input_idx >= graph_inputs.size()) {
      MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
    }
    auto backend_parameter = graph_inputs[input_idx];
    assigns.emplace_back(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
  }
  // erase repeated assigns
  std::set<std::pair<AnfNodePtr, AnfNodePtr>> inserted_nodes;
  for (auto &assign : assigns) {
    auto front_anf = assign.first;
    auto backend_parameter = assign.second;
    auto from_graph_id = GetGraphIdByNode(front_anf);
    auto from_graph = GetGraph(from_graph_id);
    MS_EXCEPTION_IF_NULL(from_graph);
    auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
    if (inserted_nodes.find(assign) == inserted_nodes.end()) {
      InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
      (void)inserted_nodes.insert(assign);
    }
  }
}

GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  for (const auto &graph_item : graphs_) {
    auto graph = graph_item.second;
    MS_EXCEPTION_IF_NULL(graph);
    // if front_anf is a parameter, there may be more than one corresponding backend parameter
    if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
      return graph_item.first;
    }
  }
  MS_EXCEPTION_IF_NULL(front_anf);
  MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " does not exist in any graph";
  return kInvalidGraphId;
}

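// Concatenate the execution orders of all child graphs (skipping branch markers)
// into the final graph, labeling every node with its child graph's stream
// distinction label, and merge the child graphs' value nodes and ref maps.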
void AscendSession::MergeGraphExecOrder() {
  MS_LOG(INFO) << "Start!";
  // merge graph order
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  auto final_graph = GetGraph(final_graph_id_);
  MS_EXCEPTION_IF_NULL(final_graph);
  if (graph_order.empty()) {
    MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
    return;
  }
  if (graph_order.size() > 1) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
      MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
    }
  }
  // if the first graph is common, the final graph has no label, so give the final graph the same stream label as
  // the first graph
  SetStreamDistinctionLabel(final_graph, graph_order[0], false);
  std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
  KernelGraphPtr last_graph = nullptr;
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto graph_id = graph_order[i];
    if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
      continue;
    }
    auto child_graph = GetGraph(graph_id);
    last_graph = child_graph;
    MS_EXCEPTION_IF_NULL(child_graph);
    auto exec_order = child_graph->execution_order();
    MS_LOG(INFO) << "Merge graph, graph_id " << graph_id;
    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
                         [&](CNodePtr node) -> CNodePtr {
                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
                           return node;
                         });
    // add all value nodes of child graphs to the final graph
    for (auto &value_node : child_graph->graph_value_nodes()) {
      final_graph->AddValueNodeToGraph(value_node);
    }
    // copy ref map to the final graph
    auto child_ref_map = child_graph->GetRefMap();
    for (auto &item : child_ref_map) {
      if (final_graph->IsInRefOutputMap(item.first)) {
        MS_LOG(EXCEPTION) << "The ref pair is already in final graph!";
      }
      final_graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
  // set final_exec_order into the final graph
  MS_EXCEPTION_IF_NULL(final_graph);
  DumpGraphExeOrder(final_exec_order);
  final_graph->set_execution_order(final_exec_order);
}

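// Insert an Assign(to, from) into the given graph and append it to the graph
// through a Depend; no-op when from and to are the same node or already share
// an output device address.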
void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
      AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
    return;
  }
  if (from.get() == to.get()) {
    return;
  }
  MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to "
               << to->DebugString();
  auto graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(graph);
  // config inputs of the assign node
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  // generate a new cnode
  auto assign_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(assign_node);
  assign_node->set_abstract(to->abstract());
  // append the assign at the end of the from graph
  AscendControlParser::InsertDependToGraph(NOT_NULL(graph), NOT_NULL(assign_node));
}

const std::vector<GraphId> &AscendSession::GetGraphOrder(GraphId final_graph_id) const {
  auto graph_order_iter = graph_execute_orders_.find(final_graph_id);
  if (graph_order_iter == graph_execute_orders_.end()) {
    MS_LOG(EXCEPTION) << "Final graph " << final_graph_id << " has no child graph";
  }
  return graph_order_iter->second;
}

const std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id) const {
  auto graph_type_iter = graph_order_types_.find(final_graph_id);
  if (graph_type_iter == graph_order_types_.end()) {
    MS_LOG(EXCEPTION) << "Final graph " << final_graph_id << " has no graph_order_types_";
  }
  return graph_type_iter->second;
}

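// Copy each recorded initial tensor from host memory into the device address of
// the corresponding input parameter of its target graph.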
void AscendSession::SyncInitialTenosrToDevice() {
  for (auto &item : initial_tenosrs_) {
    auto to_graph_id = item.first.first;
    auto input_idx = item.first.second;
    auto front_tensor = item.second;
    auto to_graph = GetGraph(to_graph_id);
    MS_EXCEPTION_IF_NULL(to_graph);
    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
    if (input_idx >= graph_inputs.size()) {
      MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
    }
    auto backend_parameter = graph_inputs[input_idx];
    // sync data from host to device
    MS_EXCEPTION_IF_NULL(front_tensor);
    size_t tensor_size = front_tensor->data().nbytes();
    auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
    MS_EXCEPTION_IF_NULL(addr);
    if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
                                front_tensor->data_type(), front_tensor->data_c())) {
      MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice failed!";
    }
  }
}

void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
  MS_LOG(INFO) << "Start BackendCommonOptimization";
  for (auto &graph : all_graphs) {
    opt::BackendCommonOptimization(graph);
  }
  MS_LOG(INFO) << "End.";
}

void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }

void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
  AscendControlParser::ExecutorValidate(graph);
}

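// For every call/switch node, create a parameter to hold the output of the
// branches, thread it back into users through a Depend(parameter, call) node, and
// insert assigns from each child graph's output to that parameter. The memo set
// guards against revisiting graphs in the (possibly cyclic) call tree.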
void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  if (memo->find(graph.get()) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  graph->UpdateChildGraphOrder();
  for (auto &child_graph : graph->child_graph_order()) {
    CreateMultiBranchOutput(NOT_NULL(child_graph.lock()), memo);
  }
  std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
  auto node_list = GetCNodes(TopoSort(graph->get_return()));
  for (auto &node : node_list) {
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
      // create a parameter to store the output of the multiple branches and set the parameter as the condition
      // graph's output
      auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
      MS_EXCEPTION_IF_NULL(graph->MutableInputs());
      graph->AddChildGraphResult(output_param);
      std::vector<AnfNodePtr> depend_inputs = {
        graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), output_param, node};
      auto depend = graph->NewCNode(depend_inputs);
      depend->set_abstract(output_param->abstract());
      need_replace_list.emplace(node, depend);
      MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString()
                   << ", depend node is " << depend->DebugString();
      // insert assigns in order to transfer the child graph outputs to the parameter
      auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node);
      for (auto &child_graph : child_graphs) {
        MS_EXCEPTION_IF_NULL(child_graph);
        // If the graph has no output, the graph is the true graph of a while loop and will call the condition graph;
        // there is no need to insert an assign from the condition graph to the true graph
        if (memo->find(child_graph) != memo->end()) {
          continue;
        }
        if (child_graph->get_output_null()) {
          continue;
        }
        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr,
                                                         NOT_NULL(child_graph->output()), NOT_NULL(output_param));
      }
    }
  }
  // search every node's inputs and replace each call by its depend(parameter, call)
  for (auto &node : node_list) {
    for (size_t i = 0; i < node->size(); ++i) {
      auto input = node->input(i);
      auto iter = need_replace_list.find(input);
      if (iter != need_replace_list.end()) {
        node->set_input(i, iter->second);
      }
    }
  }
  memo->erase(graph.get());
}

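// Run IR fusion and graph-kernel optimization over the whole call tree; like the
// other recursive passes below, the memo set ensures each graph is visited once.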
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  opt::AscendBackendIRFusionOptimization(graph);
  opt::AscendBackendFuseBasicOpt(graph, true);
  opt::AscendBackendGraphKernelOpt(graph, true);
  graph->SetExecOrderByDefault();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs) {
    if (save_graphs_path.empty()) {
      save_graphs_path = ".";
    }
    std::string file_path =
      save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_path, graph.get());
  }
  for (auto &child_graph : graph->child_graph_order()) {
    IrFusionPass(NOT_NULL(child_graph.lock()), memo);
  }
}

void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
  MS_LOG(INFO) << "Start select kernel.";
  size_t raise_precision_count = 0;
  size_t reduce_precision_count = 0;
  std::set<KernelGraphPtr> memo;
  (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
  memo.clear();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (raise_precision_count > 0) {
      MS_LOG(WARNING) << "There are " << raise_precision_count
                      << " node(s) that used raise precision to select the kernel!";
    }
    if (reduce_precision_count > 0) {
      MS_LOG(WARNING) << "There are " << reduce_precision_count
                      << " node(s) that used reduce precision to select the kernel!";
    }
  }
  MS_LOG(INFO) << "Finish!";
}

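// Select kernel info for every node in the graph; for conditional control kernels,
// first descend into the child graphs attached through the kAttrChildGraph attr.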
void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
                                            NotNull<std::set<KernelGraphPtr> *> const memo,
                                            size_t *const raise_precision_count,
                                            size_t *const reduce_precision_count) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();
  for (const auto &cnode : graph->execution_order()) {
    if (AnfAlgo::IsCondControlKernel(cnode)) {
      std::vector<KernelGraphPtr> child_graphs;
      if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) {
        child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
      }
      for (auto &child_graph : child_graphs) {
        RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count);
      }
    }
    auto status = device::ascend::SelectKernelInfo(cnode);
    if (status == device::ascend::kStatusRaisePrecision) {
      (*raise_precision_count)++;
    } else if (status == device::ascend::kStatusReducePrecision) {
      (*reduce_precision_count)++;
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs) {
    if (save_graphs_path.empty()) {
      save_graphs_path = ".";
    }
    std::string file_path =
      save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_path, graph.get());
  }
  MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
}

void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
                                     NotNull<std::set<KernelGraphPtr> *> const memo) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();
  HardwareOptimize(graph.get());
  for (auto &child_graph : graph->child_graph_order()) {
    HardwareOptimize(NOT_NULL(child_graph.lock()), memo);
  }
  MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
}

void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();
  // assign static memory for parameters
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->ClearGlobalIdleMem();
  runtime_instance->AssignStaticMemoryInput(graph.get().get());
  runtime_instance->AssignStaticMemoryValueNode(graph.get().get());
  for (auto &child_graph : graph->child_graph_order()) {
    AssignStaticMemory(NOT_NULL(child_graph.lock()), memo);
  }
  MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
}

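// Hoist each child graph's ref-output map into its parent so the whole call tree
// shares one view of ref pairs, warning about duplicates instead of failing.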
void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  for (auto &child_graph : graph->child_graph_order()) {
    std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
    MS_EXCEPTION_IF_NULL(child_graph_ptr);
    UpdateRefOutputMap(NOT_NULL(child_graph_ptr), memo);
    // copy the ref map to the parent graph
    auto child_ref_map = child_graph_ptr->GetRefMap();
    for (auto &item : child_ref_map) {
      if (graph->IsInRefOutputMap(item.first)) {
        MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second
                        << "> is already in " << graph->ToString();
        continue;
      }
      graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
}

GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
  RunInfer(func_graph, inputs);
  return CompileGraph(func_graph);
}
}  // namespace session
}  // namespace mindspore