You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

ascend_session.cc 68 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/ascend_session.h"
  17. #include <algorithm>
  18. #include <map>
  19. #include <tuple>
  20. #include <set>
  21. #include <string>
  22. #include <list>
  23. #include "base/core_ops.h"
  24. #include "base/base_ref_utils.h"
  25. #include "ir/tensor.h"
  26. #include "ir/anf.h"
  27. #include "common/trans.h"
  28. #include "runtime/device/kernel_runtime.h"
  29. #include "runtime/device/ascend/kernel_select_ascend.h"
  30. #include "runtime/device/ascend/kernel_build_ascend.h"
  31. #include "runtime/device/ascend/ascend_kernel_runtime.h"
  32. #include "runtime/device/ascend/profiling/profiling_manager.h"
  33. #include "backend/optimizer/ascend/ascend_backend_optimization.h"
  34. #include "backend/optimizer/common/common_backend_optimization.h"
  35. #include "backend/optimizer/ascend/mindir/space_batch_nd_attr_update.h"
  36. #include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
  37. #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"
  38. #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
  39. #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h"
  40. #include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
  41. #include "backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h"
  42. #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h"
  43. #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h"
  44. #include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
  45. #include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h"
  46. #include "runtime/device/kernel_adjust.h"
  47. #include "runtime/device/ascend/ascend_stream_assign.h"
  48. #include "backend/session/anf_runtime_algorithm.h"
  49. #include "utils/ms_utils.h"
  50. #include "utils/context/graph_kernel_flags.h"
  51. #include "backend/optimizer/common/helper.h"
  52. #include "runtime/device/kernel_runtime_manager.h"
  53. #include "utils/config_manager.h"
  54. #include "debug/data_dump/dump_json_parser.h"
  55. #include "debug/tensor_load.h"
  56. #include "debug/anf_ir_utils.h"
  57. #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
  58. #include "backend/session/ascend_auto_monad.h"
  59. #include "debug/data_dump/e2e_dump.h"
  60. #include "debug/anf_ir_dump.h"
  61. #include "debug/dump_proto.h"
  62. #include "abstract/utils.h"
  63. #ifdef ENABLE_DEBUGGER
  64. #include "debug/debugger/proto_exporter.h"
  65. #else
  66. #include "debug/debugger/proto_exporter_stub.h"
  67. #endif
  68. #include "toolchain/adx_datadump_server.h"
  69. #ifdef ENABLE_DUMP_IR
  70. #include "debug/rdr/running_data_recorder.h"
  71. #include "debug/rdr/recorder_manager.h"
  72. #include "debug/rdr/graph_recorder.h"
  73. #endif
  74. #if ENABLE_CPU && ENABLE_D
  75. #include "ps/util.h"
  76. #include "ps/ps_cache/ps_cache_manager.h"
  77. #endif
  78. #include "runtime/device/ascend/ascend_bucket.h"
  79. #include "profiler/device/common/memory_profiling.h"
  80. using mindspore::device::ascend::ProfilingManager;
  81. using mindspore::profiler::MemoryProfiling;
  82. static constexpr uint32_t kLabelSwitchLabelId = 2;
  83. namespace mindspore {
  84. namespace session {
  85. const size_t kInvalidIndex = SIZE_MAX;
  86. constexpr char SR_TAG[] = "sr_tag";
  87. constexpr char BACKWARD[] = "backward";
  88. namespace {
// Dump a graph's execution order for debugging: for each CNode, log its
// index, stream distinction label, owning graph id and a short description.
// The optional `tag` is included in the text-buffer form of the dump.
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
  MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
  MS_LOG(INFO) << "[index][stream_label][graph_id][node string]";
  int i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "[ " << i << "]"
                 << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]"
                 << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]"
                 << "[" << cnode->DebugString() << "]";
    i++;
  }
  // Second pass: build a human-readable text block with the same information.
  std::stringstream buf;
  buf << "================== execution order ==================\n";
  if (!tag.empty()) {
    buf << tag << "\n";
  }
  buf << "execution_order size: " << execution_order.size() << "\n";
  i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    buf << i << ":\n";
    buf << "\t" << cnode->DebugString() << "\n";
    buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n";
    buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n";
    i++;
  }
  buf << "================== execution order ==================\n";
  // NOTE(review): `buf` is composed but never logged or written in the code
  // visible here — confirm whether an emit of buf.str() was lost, or whether
  // the buffer can be removed.
}
  118. // Handle control flow by auto-monad.
  119. void HandleControlFlow(NotNull<KernelGraphPtr> graph) {
  120. AscendAutoMonad auto_monad(graph);
  121. auto_monad.Run();
  122. }
  123. void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
  124. MS_EXCEPTION_IF_NULL(graph);
  125. if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
  126. graph->set_stream_distinction_label(label);
  127. }
  128. }
  129. TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
  130. const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
  131. bool *output_is_weight) {
  132. MS_EXCEPTION_IF_NULL(output_is_weight);
  133. const auto &iter = node_output_info.find(kernel_with_index);
  134. if (iter == node_output_info.end()) {
  135. MS_LOG(EXCEPTION) << "Can not find output stub tensor of cnode " << kernel_with_index.first->DebugString();
  136. }
  137. *output_is_weight = iter->second.is_weight;
  138. return iter->second.output_stub_tensor;
  139. }
// Build stub output tensors for every referenced output of a single-op graph
// and record them in `op_output_info`, keyed by (kernel, output index).
// Only outputs present in `cnode_refcount` (i.e. actually consumed) get a
// stub. A stub carries the inferred dtype/shape plus the device format info,
// but no data — it stands in for the real output until the op runs.
void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr &kernel,
                           const std::map<KernelWithIndex, size_t> &cnode_refcount,
                           std::map<KernelWithIndex, OutputTensorInfo> *op_output_info) {
  MS_EXCEPTION_IF_NULL(single_op_graph);
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_info);
  OutputTensorInfo output_tensor_info;
  size_t out_idx = 0;
  for (const auto &output : single_op_graph->outputs()) {
    KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++);
    // Skip outputs that nobody references.
    if (cnode_refcount.find(kernel_with_index) == cnode_refcount.end()) {
      continue;
    }
    const auto &output_kernel_with_index = AnfAlgo::VisitKernel(output, 0);
    const auto &output_node = output_kernel_with_index.first;
    const auto &output_index = output_kernel_with_index.second;
    auto out_abstract = output_node->abstract();
    MS_EXCEPTION_IF_NULL(out_abstract);
    // For tuple outputs, select the element this output index refers to.
    if (out_abstract->isa<abstract::AbstractTuple>()) {
      out_abstract = out_abstract->cast<abstract::AbstractTuplePtr>()->elements()[output_index];
      MS_EXCEPTION_IF_NULL(out_abstract);
    }
    abstract::AbstractTensorPtr tensor_abstract = out_abstract->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor_abstract);
    const auto &infer_type = AnfAlgo::GetOutputInferDataType(output_node, output_index);
    // Data pointer is nullptr: the stub allocates no host memory.
    tensor::TensorPtr stub_output_tensor =
      std::make_shared<tensor::Tensor>(infer_type, tensor_abstract->shape()->shape(), nullptr);
    const auto &output_type = AnfAlgo::GetOutputDeviceDataType(output_node, output_index);
    // NOTE(review): `output_shape` appears unused below — confirm and remove.
    const auto &output_shape = AnfAlgo::GetOutputDeviceShape(output_node, output_index);
    const auto &output_format = AnfAlgo::GetOutputFormat(output_node, output_index);
    tensor::DeviceInfo device_info;
    device_info.format_ = output_format;
    device_info.data_type_ = TypeIdToType(output_type);
    stub_output_tensor->set_device_info(device_info);
    // The device address is likewise a placeholder (null pointer, size 0).
    device::DeviceAddressPtr device_address =
      std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type);
    stub_output_tensor->set_device_address(device_address);
    output_tensor_info.output_stub_tensor = stub_output_tensor;
    auto kernel_info = dynamic_cast<const device::KernelInfo *>(output_node->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    // A non-feature-map output is a weight (parameter) output.
    output_tensor_info.is_weight = !(kernel_info->is_feature_map());
    (*op_output_info)[kernel_with_index] = output_tensor_info;
  }
}
  184. bool IsBackward(const CNodePtr &cnode) {
  185. auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  186. return prim->HasAttr(BACKWARD);
  187. }
  188. // compare the value of send/recv sr_tag
  189. bool comp(const CNodePtr &node1, const CNodePtr &node2) {
  190. auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0));
  191. MS_EXCEPTION_IF_NULL(prim1);
  192. auto prim2 = GetValueNode<PrimitivePtr>(node1->input(0));
  193. MS_EXCEPTION_IF_NULL(prim2);
  194. auto sr_tag_value1 = prim1->GetAttr(SR_TAG);
  195. MS_EXCEPTION_IF_NULL(sr_tag_value1);
  196. auto sr_tag_value2 = prim2->GetAttr(SR_TAG);
  197. MS_EXCEPTION_IF_NULL(sr_tag_value2);
  198. auto sr_tag1 = GetValue<int64_t>(sr_tag_value1);
  199. auto sr_tag2 = GetValue<int64_t>(sr_tag_value2);
  200. return sr_tag1 < sr_tag2;
  201. }
  202. // Reorder the execution order of send
  203. void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  204. auto last_node = op_v.back();
  205. for (auto &node : op_v) {
  206. if (node == last_node) {
  207. continue;
  208. }
  209. auto node_iter = std::find(execution_order->begin(), execution_order->end(), node);
  210. (void)execution_order->erase(node_iter);
  211. }
  212. std::sort(op_v.begin(), op_v.end(), comp);
  213. auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node);
  214. auto node_iter = execution_order->erase(last_node_iter);
  215. // all send will insert the end of the last node
  216. execution_order->insert(node_iter, op_v.begin(), op_v.end());
  217. }
  218. // Reorder the execution order of receive
  219. void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  220. auto begin_node = op_v.front();
  221. for (auto &node : op_v) {
  222. if (node == begin_node) {
  223. continue;
  224. }
  225. auto node_iter = std::find(execution_order->begin(), execution_order->end(), node);
  226. (void)execution_order->erase(node_iter);
  227. }
  228. std::sort(op_v.begin(), op_v.end(), comp);
  229. auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node);
  230. auto node_iter = execution_order->erase(begin_node_iter);
  231. // all receive will insert before the begin node
  232. execution_order->insert(node_iter, op_v.begin(), op_v.end());
  233. }
  234. void ReorderSendRecv(std::vector<CNodePtr> *execution_order) {
  235. std::vector<CNodePtr> forward_send, forward_recv, backward_send, backward_recv;
  236. for (auto &cnode : *execution_order) {
  237. if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) {
  238. backward_send.push_back(cnode);
  239. continue;
  240. } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
  241. forward_send.push_back(cnode);
  242. continue;
  243. }
  244. if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) {
  245. backward_recv.push_back(cnode);
  246. } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
  247. forward_recv.push_back(cnode);
  248. }
  249. }
  250. if (!forward_send.empty()) {
  251. ReorderSend(execution_order, forward_send);
  252. }
  253. if (!backward_send.empty()) {
  254. ReorderSend(execution_order, backward_send);
  255. }
  256. if (!forward_recv.empty()) {
  257. ReorderRecv(execution_order, forward_recv);
  258. }
  259. if (!backward_recv.empty()) {
  260. ReorderRecv(execution_order, backward_recv);
  261. }
  262. }
// Append the graph's control-input tensors to `inputs` and return how many
// control tensors the graph declares (0 when it has none). The fixed layout
// is: [0] current loop count, [1] next loop count, [2] epoch. The two loop
// counters are reset to 0; the epoch tensor receives the graph's current
// epoch, which is then advanced by one.
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Load kInputCtrlTensors";
  auto inputs_params = graph->input_ctrl_tensors();
  if (inputs_params == nullptr) {
    return 0;
  }
  // The first three slots are mandatory (layout above).
  if (inputs_params->size() < 3) {
    MS_LOG(EXCEPTION) << "Illegal inputs_params size";
  }
  // Reset the current-loop counter to 0 for this iteration.
  auto cur_loop_tensor = (*inputs_params)[0];
  MS_EXCEPTION_IF_NULL(cur_loop_tensor);
  auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
  MS_EXCEPTION_IF_NULL(cur_val);
  *cur_val = 0;
  cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(cur_loop_tensor);
  // Reset the next-loop counter to 0 for this iteration.
  auto next_loop_tensor = (*inputs_params)[1];
  MS_EXCEPTION_IF_NULL(next_loop_tensor);
  auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
  MS_EXCEPTION_IF_NULL(next_val);
  *next_val = 0;
  next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(next_loop_tensor);
  // Write the current epoch into the epoch tensor.
  auto epoch_tensor = (*inputs_params)[2];
  MS_EXCEPTION_IF_NULL(epoch_tensor);
  auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
  MS_EXCEPTION_IF_NULL(epoch_val);
  *epoch_val = graph->current_epoch();
  epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
  inputs->push_back(epoch_tensor);
  MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
  graph->set_current_epoch(graph->current_epoch() + 1);
  return inputs_params->size();
}
// Decide whether `tensor` must be synced host-to-device into the device
// address bound to `parameter`. In PyNative-infer mode, sync whenever the
// tensor has no device address or a different one than the parameter.
// Otherwise, sync when the tensor is flagged as needing a host-to-device
// sync, or when it is bound to a different device address.
bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
  }
  if (tensor->NeedSyncHostToDevice()) {
    return true;
  }
  auto tensor_address = tensor->device_address();
  if (tensor_address != device_address) {
    // Tensor currently lives at another address: refresh its host copy first
    // so the caller can upload fresh values. NOTE(review): data_sync(false)
    // presumably means a non-blocking sync — confirm against Tensor's API.
    tensor->data_sync(false);
    return true;
  }
  return false;
}
  321. } // namespace
  322. void AscendSession::Init(uint32_t device_id) { InitExecutor(kAscendDevice, device_id); }
// Run the MindIR unification pipeline on `graph`: rewrites front-end ops
// (pooling, conv, dropout, optimizer outputs, quantization grads, sparse
// softmax cross-entropy, ...) into the forms the Ascend backend expects.
// Pass selection differs between graph mode and PyNative mode. The graph is
// dumped before and after when the save-graphs flag is on.
void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    std::string file_name = "hwopt_d_before_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph);
    DumpIRProto(graph, "before_unify_mindir_hwopt_" + std::to_string(graph->graph_id()));
  }
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
  // Passes applied in both execution modes (registration order matters).
  unify_mindir_pm->AddPass(std::make_shared<opt::SpaceToBatchNDAttrUpdate>());
  unify_mindir_pm->AddPass(std::make_shared<opt::BatchToSpaceNDAttrUpdate>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::SliceGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::AvgPoolGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::FtrlUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MomentumUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::RMSPropUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::CenteredRMSPropUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::FakeLearnedScaleQuantPerLayerGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::FakeLearnedScaleQuantPerChannelGradUnifyMindIR>());
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    // Graph-mode-only rewrites: dropout fusion and the graph-mode variants of
    // sparse softmax cross-entropy.
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutAndDropoutGradUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR0>());
    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
  } else {
    // PyNative-mode variants of the sparse softmax cross-entropy rewrites.
    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
  }
  unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
  unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::BatchNormGradUnifyMindIR>());
  optimizer->AddPassManager(unify_mindir_pm);
  (void)optimizer->Optimize(graph);
  // Re-derive the execution order after the rewrites.
  graph->SetExecOrderByDefault();
  if (save_graphs) {
    std::string file_name = "hwopt_d_after_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph);
  }
}
// Copy host input tensors into the graph's parameter device addresses.
// Control-input tensors (loop counters, epoch) are appended to the input list
// first when the graph declares them. Tensors already resident at the correct
// device address are skipped; every processed tensor is marked kNoNeedSync.
void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                  const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // Default matches the three mandatory ctrl tensors of LoadCtrlInputTensor.
  size_t input_ctrl_size = 3;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto &input_nodes = kernel_graph->input_nodes();
  // NOTE(review): the "-3" offset assumes exactly three ctrl tensors were
  // appended above — confirm against LoadCtrlInputTensor's contract.
  if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    auto size = LongToSize(tensor->data().nbytes());
    // Dynamic-shape parameters: refresh the node's inferred shape from the
    // actual tensor and recompute the byte size accordingly.
    if (input_node->isa<Parameter>() && input_node->cast<ParameterPtr>()->is_used_by_dynamic_kernel()) {
      auto tensor_shape = tensor->shape();
      std::vector<size_t> shape_tmp;
      (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
      AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                          input_node.get());
      size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
    }
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
#if (ENABLE_CPU && !_WIN32)
      // Parameters backed by the PS cache are managed elsewhere; skip upload.
      const std::string &param_name = input_node->fullname_with_scope();
      if (ps::ps_cache_instance.IsHashTable(param_name)) {
        continue;
      }
#endif
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
      MS_EXCEPTION_IF_NULL(device_address);
      if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size,
                                                         tensor->data_type(), tensor->data_c())) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
      }
      // In PyNative mode (or for weights) remember the device address on the
      // tensor so later iterations can skip the upload.
      if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
          AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
        tensor->set_device_address(device_address);
      }
    }
    tensor->set_sync_status(kNoNeedSync);
  }
}
  423. GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  424. MS_LOG(INFO) << "Start";
  425. // construct graph, if successfully, graph_sum_ + 1
  426. auto graph = ConstructKernelGraph(lst, outputs);
  427. auto graph_id = graph->graph_id();
  428. InitAllBucket(graph);
  429. MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  430. return graph_id;
  431. }
  432. GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  433. MS_LOG(INFO) << "Start";
  434. std::vector<KernelGraphPtr> all_graphs;
  435. auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  436. // Update Graph Dynamic Shape Attr
  437. UpdateAllGraphDynamicShapeAttr(all_graphs);
  438. UnifyMindIR(root_graph);
  439. opt::BackendCommonOptimization(root_graph);
  440. // empty graph dont entry to backend
  441. if (root_graph->execution_order().empty()) {
  442. MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
  443. AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(root_graph));
  444. root_graph->set_executable(false);
  445. InitRuntimeResource();
  446. return root_graph->graph_id();
  447. }
  448. // Handle control flow by auto-monad.
  449. HandleControlFlow(NOT_NULL(root_graph));
  450. // resource initialize
  451. InitRuntimeResource();
  452. std::set<KernelGraphPtr> memo;
  453. IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
  454. memo.clear();
  455. SelectKernel(NOT_NULL(root_graph));
  456. memo.clear();
  457. HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
  458. memo.clear();
  459. // load graphs to debugger.
  460. if (debugger_ && debugger_->DebuggerBackendEnabled()) {
  461. LoadGraphsToDbg(NOT_NULL(root_graph), NOT_NULL(&memo));
  462. }
  463. memo.clear();
  464. UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
  465. memo.clear();
  466. // add make_tuple to the output graph
  467. AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(root_graph));
  468. // root root_graph valiate,include genearte execute order and so on
  469. RootGraphExecutorValidate(NOT_NULL(root_graph));
  470. // dump graph before remove nop nodes
  471. auto context_ptr = MsContext::GetInstance();
  472. MS_EXCEPTION_IF_NULL(context_ptr);
  473. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  474. if (save_graphs) {
  475. DumpIRProto(root_graph, "before_removeNop_" + std::to_string(graph_sum_));
  476. }
  477. // adjust kernel
  478. AdjustKernel(root_graph);
  479. // reorder send/recv
  480. auto execution_order = root_graph->execution_order();
  481. ReorderSendRecv(&execution_order);
  482. root_graph->set_execution_order(execution_order);
  483. #if ENABLE_CPU && ENABLE_D
  484. InitPsWorker(root_graph);
  485. #endif
  486. // assign stream
  487. AssignStream(NOT_NULL(root_graph));
  488. // insert profiling point
  489. device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
  490. // build kernel
  491. BuildKernel(root_graph);
  492. if (debugger_ && debugger_->partial_memory()) {
  493. debugger_->PreExecute(root_graph, graph_sum_);
  494. }
  495. SetSummaryNodes(root_graph.get());
  496. // Alloc memory for child graph's inputs
  497. AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
  498. memo.clear();
  499. // Alloc memory for root graph's inputs and node's outputs, workspace
  500. MemoryAlloc(root_graph.get());
  501. // generate and load task into device
  502. Load(root_graph);
  503. root_graph->SetInputNodes();
  504. root_graph->SetOptimizerFlag();
  505. DumpAllGraphs(all_graphs);
  506. // Save memory profiling data to proto file
  507. if (ProfilingManager::GetInstance().IsProfiling()) {
  508. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  509. MS_EXCEPTION_IF_NULL(runtime_instance);
  510. uint64_t mem_size = runtime_instance->GetAvailableMemMaxSize();
  511. auto instance = MemoryProfiling::GetInstance();
  512. instance.SetDeviceMemSize(mem_size);
  513. instance.SaveMemoryProfiling();
  514. }
  515. // return the root_graph id to backend
  516. auto graph_id = root_graph->graph_id();
  517. return graph_id;
  518. }
  519. void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
  520. MS_EXCEPTION_IF_NULL(kernel_graph);
  521. auto graph_order = GetGraphOrder(kernel_graph->graph_id());
  522. for (auto graph_id : graph_order) {
  523. auto child_graph = GetGraph(graph_id);
  524. if (child_graph == nullptr) {
  525. continue;
  526. }
  527. if (child_graph->summary_node_exist()) {
  528. kernel_graph->set_summary_node_exist(true);
  529. return;
  530. }
  531. }
  532. kernel_graph->set_summary_node_exist(false);
  533. }
// Builds the graph identified by graph_id into an executable form.
// For the final (multi-)graph this compiles every child graph that is not a
// branch start/end marker and merges their execution orders; for a plain graph
// it compiles the graph itself. Afterwards kernels are adjusted and built,
// streams assigned, and — unless precompile-only mode is set — memory is
// allocated and tasks are loaded to the device.
void AscendSession::BuildGraphImpl(GraphId graph_id) {
  MS_LOG(INFO) << "Start";
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  // resource initialize
  InitRuntimeResource();
  // multiple graph handle
  if (graph_id == final_graph_id_) {
    if (!graph->executable()) {
      return;
    }
    SetFinalGraphSummaryFlag(graph);
    // OptChildGraphs
    auto graph_order = GetGraphOrder(final_graph_id_);
    auto &graph_type = GetGraphOrderType(final_graph_id_);
    for (size_t i = 0; i < graph_order.size(); i++) {
      // BRANCH_START/BRANCH_END entries are control-flow markers, not real child graphs.
      if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) {
        auto child_graph = GetGraph(graph_order[i]);
        CompileChildGraph(child_graph);
      }
    }
    SetSummaryNodes(graph.get());
    // merge child graph
    MergeGraphExecOrder();
  } else {
    auto single_graph = GetGraph(graph_id);
    MS_EXCEPTION_IF_NULL(single_graph);
    CompileChildGraph(single_graph);
    // set the distinction label of single graph
    single_graph->set_stream_distinction_label(graph_id);
    single_graph->UpdateExecuteKernelStreamLabel();
  }
  // adjust execution order because merge child graph and other special operations
  AdjustKernel(graph);
#if ENABLE_CPU && ENABLE_D
  InitPsWorker(graph);
#endif
  // Assign streams for control sink and hccl and so on
  AssignStream(NOT_NULL(graph));
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
  // build kernel if node is cnode
  BuildKernel(graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // With partial memory, the debugger hooks in at build time rather than run time.
  if (debugger_ && debugger_->partial_memory()) {
    debugger_->PreExecute(graph, graph_sum_);
  }
  if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "Precompile only, stop in build kernel step";
  } else {
    // alloc memory, including static memory and dynamic memory
    MemoryAlloc(graph.get());
    // generate and load task info to device if it is sink mode
    Load(graph);
  }
  // sync the initial const tensor to device
  SyncInitialTenosrToDevice();
  DumpAllGraphs({graph});
  MS_LOG(INFO) << "End";
}
  594. void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
  595. MS_EXCEPTION_IF_NULL(child_graph);
  596. MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
  597. opt::AscendBackendIRFusionOptimization(child_graph);
  598. child_graph->SetExecOrderByDefault();
  599. auto context_ptr = MsContext::GetInstance();
  600. MS_EXCEPTION_IF_NULL(context_ptr);
  601. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  602. if (save_graphs) {
  603. std::string file_name = "select_kernel_before_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
  604. DumpIR(file_name, child_graph);
  605. }
  606. // select kernel build info
  607. SelectKernel(*child_graph);
  608. if (save_graphs) {
  609. std::string file_name = "select_kernel_after_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
  610. DumpIR(file_name, child_graph);
  611. }
  612. // optimize graph
  613. HardwareOptimize(child_graph);
  614. // assign static memory of parameters
  615. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  616. MS_EXCEPTION_IF_NULL(runtime_instance);
  617. runtime_instance->AssignStaticMemoryInput(child_graph.get());
  618. runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
  619. }
  620. bool AscendSession::IsSupportSummary() { return !device::KernelAdjust::NeedInsertSwitch(); }
// Hook run right before graph execution: notifies the debugger and, when built
// with parameter-server support, initializes PS state and advances the data
// prefetch step for graphs that read from a dataset channel.
void AscendSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
                                    const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
  if (debugger_) {
    debugger_->PreExecute(kernel_graph, graph_sum_);
  }
#if ENABLE_CPU && ENABLE_D
  // Initialize parameter server
  InitPSParamAndOptim(kernel_graph, inputs);
  std::string channel_name;
  // Only GetNext-style graphs with the PS data cache enabled bump the cache step.
  if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(kernel_graph, &channel_name)) {
    ps::ps_cache_instance.IncreaseGraphStep(channel_name);
  }
#endif
}
  635. void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
  636. const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
  637. // summary
  638. Summary(kernel_graph.get());
  639. // load tensor from device for debugger
  640. if (debugger_ && debugger_->debugger_enabled()) {
  641. LoadTensor(kernel_graph);
  642. }
  643. // debugger post-execution processing
  644. if (debugger_) {
  645. debugger_->PostExecute();
  646. }
  647. }
  648. void AscendSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { Execute(kernel_graph, true); }
  649. void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
  650. MS_LOG(INFO) << "Start";
  651. // data layout optimization
  652. opt::AscendDataLayout(kernel_graph);
  653. // mixed precision optimization
  654. opt::AscendMixPrecision(kernel_graph);
  655. MS_LOG(INFO) << "Finish";
  656. }
  657. bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  658. return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
  659. }
  660. void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  661. const std::vector<tensor::TensorPtr> &input_tensors,
  662. const std::vector<int64_t> &tensors_mask) {
  663. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
  664. if (GraphCacheExist(graph_info)) {
  665. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !";
  666. return;
  667. }
  668. const auto &graph = PreBuildOp(op_run_info, graph_info, input_tensors, tensors_mask);
  669. MS_EXCEPTION_IF_NULL(graph);
  670. // init runtime resource
  671. InitRuntimeResource();
  672. // build kernel
  673. RunOpAdjustKernel(graph);
  674. BuildKernel(graph);
  675. run_op_graphs_[graph_info] = graph;
  676. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
  677. }
// Runs one op in PyNative mode: ensures its single-op graph is built, waits
// for pending device work on the inputs, allocates memory, loads inputs,
// executes, and collects outputs. Statement order matters: nop nodes must be
// removed before memory assignment, and dynamic-shape kernels are (re)built
// after memory/event setup.
void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                              std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                              const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(op_run_info);
  BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
  // Tensors that became value nodes in the graph are no longer runtime inputs.
  EraseValueNodeTensor(tensors_mask, input_tensors);
  // wait for allreduce
  for (auto &tensor : *input_tensors) {
    if (tensor->NeedWaitDevice()) {
      tensor->WaitDevice();
    }
  }
  // Run op
  auto graph = run_op_graphs_[graph_info];
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!";
  // malloc mem
  RunOpRemoveNopNode(graph);
  RunOpMemoryAlloc(*input_tensors, graph.get());
  RunOpGenKernelEvent(graph.get());
  // Build dynamic kernel
  if (op_run_info->is_dynamic_shape) {
    BuildDynamicKernel(graph);
  }
  // load input data to device
  LoadInputData(graph, *input_tensors);
  // run op
  Execute(graph, false);
  // get output
  UpdateOutputs(graph, outputs, *input_tensors);
  // update output abstract of dynamic op to op_run_info
  if (op_run_info->is_dynamic_shape) {
    UpdateOutputAbstract(graph, op_run_info);
  }
  // Per-op memory is released immediately after the run.
  RunOpMemoryClear(graph.get());
  MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!";
}
  716. KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  717. const std::vector<tensor::TensorPtr> &input_tensors,
  718. const std::vector<int64_t> &tensors_mask) {
  719. // Construct graph include one op
  720. auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true);
  721. MS_EXCEPTION_IF_NULL(graph);
  722. opt::RunOpAscendBackendIRFusionOptimization(graph);
  723. SelectKernel(*graph);
  724. RunOpHardwareOptimize(graph);
  725. return graph;
  726. }
// Collects the input tensors (and matching masks) feeding `cnode` for
// single-op pre-building. An input can come from three places: a value node
// (constant), a parameter (graph input or weight), or another cnode whose
// stub output was recorded earlier in node_output_info.
void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
                                          const std::vector<tensor::TensorPtr> &graph_inputs,
                                          const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
                                          InputTensorInfo *input_tensor_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_tensor_info);
  // Input 0 of a CNode is the primitive; real inputs start at index 1.
  for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
    const auto &input = cnode->input(i);
    // Resolve to the real producing node and its output index.
    auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    tensor::TensorPtr tensor = nullptr;
    if (real_input->isa<ValueNode>()) {
      tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
      input_tensor_info->input_tensors_mask.emplace_back(kParameterDataTensorMask);
    } else if (real_input->isa<Parameter>()) {
      tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
      auto parameter = real_input->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(parameter);
      // Parameters with a default value are weights; the others are data inputs.
      input_tensor_info->input_tensors_mask.emplace_back(parameter->has_default() ? kParameterWeightTensorMask
                                                                                  : kParameterDataTensorMask);
    } else if (real_input->isa<CNode>()) {
      bool output_is_weight = false;
      tensor = GetCNodeOutputStubTensor(kernel_with_index, node_output_info, &output_is_weight);
      input_tensor_info->input_tensors_mask.emplace_back(output_is_weight ? kParameterWeightTensorMask
                                                                          : kParameterDataTensorMask);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
    }
    MS_EXCEPTION_IF_NULL(tensor);
    MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
    input_tensor_info->input_tensors.emplace_back(tensor);
  }
}
// Pre-builds every op of a graph as single-op graphs (PyNative warm-up).
// First pass collects the kernels of each new single-op graph (skipping those
// already cached and stopping at the first dynamic-shape op), then all kernels
// are compiled in one parallel batch; a second batch catches kernels added by
// KernelBuildPreprocess. Finally the single-op graphs are cached under both
// the original and the possibly-updated graph info.
void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                                    const std::vector<tensor::TensorPtr> &graph_inputs,
                                    const std::map<KernelWithIndex, size_t> &cnode_refcount) {
  if (built_graph_id_.find(graph_id) != built_graph_id_.end()) {
    return;  // this graph's ops were already pre-built
  }
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  std::map<KernelWithIndex, OutputTensorInfo> op_output_info;
  std::vector<CNodePtr> kernels;
  std::unordered_map<KernelGraphPtr, std::vector<GraphInfo>> single_op_graphs;
  // Collect kernels need to be built in single op graphs
  for (const auto &kernel : graph->execution_order()) {
    // Generate fake input tensors, tensor masks and input kernel with index
    InputTensorInfo input_tensor_info;
    GetOpInputStubTensors(kernel, parameter_index, graph_inputs, op_output_info, &input_tensor_info);
    // Get OpRunInfo and GraphInfo
    OpRunInfo op_run_info;
    GetSingleOpRunInfo(kernel, &op_run_info);
    if (op_run_info.is_dynamic_shape) {
      MS_LOG(INFO) << "BuildOpsInGraph stop, op " << op_run_info.op_name << " is dynamic shape.";
      break;
    }
    const GraphInfo &graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
    const auto &single_op_graph_iter = run_op_graphs_.find(graph_info);
    if (single_op_graph_iter != run_op_graphs_.end()) {
      // if graph of same single op exists, the output tensor of current op should be generated
      const auto &single_op_graph = single_op_graph_iter->second;
      GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info);
      continue;
    }
    const auto &single_op_graph =
      PreBuildOp(op_run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
    MS_EXCEPTION_IF_NULL(single_op_graph);
    GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info);
    opt::HideNopNode(single_op_graph.get());
    // The graph info could have been changed in PreBuildOp
    const GraphInfo &new_graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
    single_op_graphs.insert({single_op_graph, {graph_info, new_graph_info}});
    const auto &execution_order = single_op_graph->execution_order();
    std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
  }
  InitRuntimeResource();
  // Compile all kernels parallel
  BuildKernel(kernels);
  // Some new kernel may be added after KernelBuildPreprocess, so collect and build kernels again
  kernels.clear();
  for (const auto &single_op_graph : single_op_graphs) {
    device::ascend::KernelBuildPreprocess(single_op_graph.first.get());
    const auto &execution_order = single_op_graph.first->execution_order();
    std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
  }
  BuildKernel(kernels);
  // Record single op graphs in run_op_graphs_ so that these graphs can be reused in BuildOpImpl
  for (const auto &single_op_graph : single_op_graphs) {
    RunOpMemoryClear(single_op_graph.first.get());
    for (const auto &graph_info : single_op_graph.second) {
      run_op_graphs_[graph_info] = single_op_graph.first;
      MS_LOG(DEBUG) << "Pre build op finished, graph info: " << graph_info;
    }
  }
  built_graph_id_.insert(graph_id);
}
  825. // compile graph steps
  826. void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
  827. MS_LOG(INFO) << "Start!";
  828. size_t raise_precision_count = 0;
  829. size_t reduce_precision_count = 0;
  830. for (const auto &cnode : kernel_graph.execution_order()) {
  831. auto status = device::ascend::SelectKernelInfo(cnode);
  832. AnfAlgo::EraseNodeAttr(kAttrPynativeNextOpName, cnode);
  833. AnfAlgo::EraseNodeAttr(kAttrPynativeNextIndex, cnode);
  834. if (status == device::ascend::kStatusRaisePrecision) {
  835. raise_precision_count++;
  836. } else if (status == device::ascend::kStatusReducePrecision) {
  837. reduce_precision_count++;
  838. }
  839. MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  840. }
  841. auto ms_context = MsContext::GetInstance();
  842. MS_EXCEPTION_IF_NULL(ms_context);
  843. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  844. if (raise_precision_count > 0) {
  845. MS_LOG(WARNING) << "There has " << raise_precision_count
  846. << " node/nodes used raise precision to selected the kernel!";
  847. }
  848. if (reduce_precision_count > 0) {
  849. MS_LOG(WARNING) << "There has " << reduce_precision_count
  850. << " node/nodes used reduce precision to selected the kernel!";
  851. }
  852. }
  853. MS_LOG(INFO) << "Finish!";
  854. }
  855. void DumpInit() {
  856. auto &json_parser = DumpJsonParser::GetInstance();
  857. json_parser.Parse();
  858. json_parser.CopyJsonToDir();
  859. if (json_parser.async_dump_enabled()) {
  860. if (AdxDataDumpServerInit() != 0) {
  861. MS_LOG(EXCEPTION) << "Adx data dump server init failed";
  862. }
  863. }
  864. }
  865. void AscendSession::InitRuntimeResource() {
  866. MS_LOG(INFO) << "Start!";
  867. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  868. MS_EXCEPTION_IF_NULL(runtime_instance);
  869. if (!runtime_instance->Init()) {
  870. MS_LOG(EXCEPTION) << "Kernel runtime init error.";
  871. }
  872. DumpInit();
  873. MS_LOG(INFO) << "Finish!";
  874. }
  875. void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  876. MS_LOG(INFO) << "HardwareOptimize start!";
  877. opt::AscendBackendOptimization(kernel_graph);
  878. FinalOptimize(kernel_graph);
  879. GraphKernelOptimize(kernel_graph);
  880. MS_EXCEPTION_IF_NULL(kernel_graph);
  881. kernel_graph->SetExecOrderByDefault();
  882. MS_LOG(INFO) << "HardwareOptimize Finish!";
  883. }
  884. void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  885. if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
  886. return;
  887. }
  888. opt::GraphKernelOptimize(kernel_graph);
  889. kernel_graph->SetExecOrderByDefault();
  890. }
  891. void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  892. MS_LOG(INFO) << "Start!";
  893. opt::HideNopNode(kernel_graph.get());
  894. // Insert CLearZero op
  895. // prepare for next step from json get atomic info
  896. BuildKernel(kernel_graph);
  897. device::ascend::KernelBuildPreprocess(kernel_graph.get());
  898. device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
  899. auto context_ptr = MsContext::GetInstance();
  900. MS_EXCEPTION_IF_NULL(context_ptr);
  901. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  902. if (save_graphs) {
  903. DumpIR("after_adjust_kernel.ir", kernel_graph);
  904. }
  905. MS_LOG(INFO) << "Finish!";
  906. }
  907. void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  908. MS_LOG(INFO) << "Start!";
  909. RunOpHideNopNode(kernel_graph);
  910. // Insert CLearZero op
  911. // prepare for next step from json get atomic info
  912. BuildKernel(kernel_graph);
  913. device::ascend::KernelBuildPreprocess(kernel_graph.get());
  914. MS_LOG(INFO) << "Finish!";
  915. }
  916. void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
  917. MS_LOG(INFO) << "Start!";
  918. device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph);
  919. MS_LOG(INFO) << "Finish!";
  920. }
  921. void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  922. BuildKernel(kernel_graph->execution_order());
  923. }
  924. void AscendSession::BuildKernel(const std::vector<CNodePtr> &kernels) const {
  925. MS_LOG(INFO) << "Start!";
  926. struct timeval start_time, end_time;
  927. (void)gettimeofday(&start_time, nullptr);
  928. auto ret = device::ascend::KernelBuild(kernels);
  929. if (!ret) {
  930. MS_LOG(EXCEPTION) << "Kernel build error.";
  931. }
  932. (void)gettimeofday(&end_time, nullptr);
  933. const uint64_t kUSecondInSecond = 1000000;
  934. uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  935. cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  936. MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost;
  937. MS_LOG(INFO) << "Finish!";
  938. }
  939. void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  940. MS_LOG(INFO) << "Start!";
  941. MS_EXCEPTION_IF_NULL(kernel_graph);
  942. const auto &kernels = kernel_graph->execution_order();
  943. auto iter = std::find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
  944. return AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape);
  945. });
  946. if (iter == kernels.end()) {
  947. return;
  948. }
  949. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  950. MS_EXCEPTION_IF_NULL(runtime_instance);
  951. if (!runtime_instance->GenDynamicKernel(kernel_graph.get())) {
  952. MS_LOG(DEBUG) << "Graph:" << kernel_graph->graph_id() << " failed to generate dynamic kernel!";
  953. }
  954. MS_LOG(INFO) << "Finish!";
  955. }
  956. static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
  957. uint32_t node_sizes = kernel_nodes.size();
  958. if (index >= node_sizes - 1) {
  959. MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
  960. }
  961. auto kernel = kernel_nodes[index + 1];
  962. if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
  963. MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
  964. << kernel_nodes[index]->DebugString();
  965. }
  966. return kernel;
  967. }
// Splits the execution order of a recursive-call region into a "front" part
// (returned) and a "back" part (appended to *back). Nodes are collected into
// `front` until the first node tagged kAttrRecursive, then into *back. A
// nested recursive node whose labels do not include back_label starts a
// recursive split whose front is spliced into this front and whose back is
// accumulated in back_temp (spliced into *back when kAttrRecursiveEnd is
// found). On return *index points at the kAttrRecursiveEnd node, so the
// caller's loop continues after the region.
static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label,
                                                 uint32_t *index, std::vector<CNodePtr> *back) {
  MS_EXCEPTION_IF_NULL(index);
  MS_EXCEPTION_IF_NULL(back);
  std::vector<CNodePtr> front;
  std::vector<CNodePtr> back_temp;
  bool back_flag = false;
  for (uint32_t i = *index; i < kernel_cnodes.size(); i++) {
    if (!back_flag) {
      front.emplace_back(kernel_cnodes[i]);
    } else {
      back->emplace_back(kernel_cnodes[i]);
    }
    if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
      *index = i;
      back->insert(back->end(), back_temp.begin(), back_temp.end());
      return front;
    }
    if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
      back_flag = true;
      if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
        continue;
      } else {
        // Nested region: ++i skips the recursive node itself; the recursion
        // advances i to that region's end before this loop resumes.
        auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
        front.insert(front.end(), temp.begin(), temp.end());
        continue;
      }
    }
  }
  return front;
}
// Rewrites the graph's memory-reuse execution order so recursive-call regions
// are unfolded: each node tagged kAttrRecursiveStart opens a region that
// HandleRecursiveCall splits into front/back parts, re-appended in that order.
// No-op for graphs without recursive calls.
static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (!kernel_graph->recursive_call()) {
    return;
  }
  auto kernel_cnodes = kernel_graph->mem_reuse_exec_order();
  std::vector<CNodePtr> mem_reuse_order;
  mem_reuse_order.reserve(kernel_cnodes.size());
  for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
    if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
      mem_reuse_order.emplace_back(kernel_cnodes[i]);
      continue;
    }
    auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
    std::vector<CNodePtr> back;
    // HandleRecursiveCall advances i past the whole recursive region.
    auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back);
    mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end());
    mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end());
  }
  kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
}
  1020. static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node,
  1021. std::vector<CNodePtr> *mem_reuse_order) {
  1022. MS_EXCEPTION_IF_NULL(kernel_graph);
  1023. MS_EXCEPTION_IF_NULL(mem_reuse_order);
  1024. auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
  1025. auto kernel_cnodes = kernel_graph->execution_order();
  1026. for (auto i = index; i < kernel_cnodes.size(); i++) {
  1027. mem_reuse_order->emplace_back(kernel_cnodes[i]);
  1028. if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
  1029. return;
  1030. }
  1031. }
  1032. }
// Builds the memory-reuse execution order for graphs with multi-call
// subgraphs: whenever a LabelGoto/LabelSwitch jumps BACK to a label already
// seen (recorded in label_id_index_map by its LabelSet), the subgraph's
// kernels are appended again so memory reuse sees every execution of them.
// Recursive and return jumps are excluded here; recursion is unfolded
// afterwards by UnfoldRecursiveExecOrder.
void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (!kernel_graph->subgraph_multi_call()) {
    return;
  }
  // label id -> index of its LabelSet node in the execution order
  std::unordered_map<uint32_t, uint32_t> label_id_index_map;
  auto kernel_cnodes = kernel_graph->execution_order();
  std::vector<CNodePtr> mem_reuse_order;
  for (size_t i = 0; i < kernel_cnodes.size(); i++) {
    mem_reuse_order.emplace_back(kernel_cnodes[i]);
    // LabelSwitch: replay every already-seen target subgraph.
    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
        !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
      auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
      for (auto label_id : label_list) {
        if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
          continue;  // forward jump: subgraph not seen yet
        }
        auto back_node = GetNextLabelSet(kernel_cnodes, i);
        GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
      }
      continue;
    }
    // LabelGoto: replay the target subgraph when jumping backwards.
    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
        !AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
      auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
      if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
        continue;
      }
      auto back_node = GetNextLabelSet(kernel_cnodes, i);
      GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
      continue;
    }
    // LabelSet: record where each (non-recursive) label starts.
    if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
        !AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
      auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
      if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
        MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
      }
      label_id_index_map[label_id] = i;
      continue;
    }
  }
  kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
  UnfoldRecursiveExecOrder(kernel_graph);
}
  1080. void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  1081. MS_LOG(INFO) << "Start!";
  1082. MS_EXCEPTION_IF_NULL(kernel_graph);
  1083. InitMemReuseExecOrder(kernel_graph);
  1084. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1085. MS_EXCEPTION_IF_NULL(runtime_instance);
  1086. runtime_instance->AssignMemory(kernel_graph);
  1087. MS_LOG(INFO) << "Finish!";
  1088. }
  1089. void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
  1090. KernelGraph *kernel_graph) const {
  1091. MS_LOG(INFO) << "Start memory alloc!";
  1092. MS_EXCEPTION_IF_NULL(kernel_graph);
  1093. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1094. MS_EXCEPTION_IF_NULL(runtime_instance);
  1095. runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
  1096. MS_LOG(INFO) << "Finish!";
  1097. }
  1098. void AscendSession::RunOpGenKernelEvent(const KernelGraph *graph) const {
  1099. MS_EXCEPTION_IF_NULL(graph);
  1100. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1101. MS_EXCEPTION_IF_NULL(runtime_instance);
  1102. runtime_instance->GenKernelEvents(graph);
  1103. }
  1104. void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
  1105. MS_EXCEPTION_IF_NULL(kernel_graph);
  1106. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1107. MS_EXCEPTION_IF_NULL(runtime_instance);
  1108. runtime_instance->RunOpClearMemory(kernel_graph);
  1109. }
  1110. void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  1111. MS_LOG(INFO) << "Start!";
  1112. auto context_ptr = MsContext::GetInstance();
  1113. MS_EXCEPTION_IF_NULL(context_ptr);
  1114. bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  1115. (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
  1116. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1117. MS_EXCEPTION_IF_NULL(runtime_instance);
  1118. bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
  1119. if (!ret_ok) {
  1120. MS_LOG(EXCEPTION) << "Load task error!";
  1121. }
  1122. MS_LOG(INFO) << "Finish!";
  1123. }
// Launch the graph on device. When is_task is true the context decides whether
// task-sink mode is used; single-op runs pass is_task=false and always run
// non-sink. Dump runs regardless of the result so partial data is captured;
// on failure all RDR recorders are flushed (when compiled in) before raising.
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
  MS_LOG(INFO) << "Start!";
  bool is_task_sink = false;
  if (is_task) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  }
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
  // Dump before checking the result so data is available even for failed runs.
  Dump(kernel_graph);
  if (!ret_ok) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    MS_LOG(EXCEPTION) << "run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
  1144. void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  1145. MS_LOG(INFO) << "Start!";
  1146. MS_EXCEPTION_IF_NULL(kernel_graph);
  1147. E2eDump::DumpData(kernel_graph.get(), device_id_);
  1148. MS_LOG(INFO) << "Finish!";
  1149. }
// Dumps IR/proto snapshots of every compiled graph when any dump consumer is
// active: save_graphs, e2e dump, async dump, or RDR. Compiled out entirely
// without ENABLE_DUMP_IR. The e2e and async branches differ only in the root
// directory layout (e2e adds a net-name level).
void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  auto &json_parser = DumpJsonParser::GetInstance();
  json_parser.Parse();
  // Nothing to do when no dump consumer is enabled.
  if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
      !mindspore::RecorderManager::Instance().RdrEnable()) {
    return;
  }
  auto kernel_runtime = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(kernel_runtime);
  uint32_t device_id = kernel_runtime->device_id();
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    // Always record into RDR; the recorder decides internally whether to keep it.
    std::string name = "graph_build." + std::to_string(graph->graph_id());
    DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
    mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");
    if (save_graphs) {
      std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
      DumpIR(file_name, graph, true, kWholeStack);
      DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
      DumpIR("trace_code_graph", graph, true, kWholeStack);
    }
    std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
    if (json_parser.e2e_dump_enabled()) {
      // e2e layout: <path>/<net_name>/device_<id>/graphs
      std::string root_dir = json_parser.path() + "/" + json_parser.net_name() + "/device_" + std::to_string(device_id);
      std::string target_dir = root_dir + "/graphs";
      std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
      DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
      DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
      DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
                        graph->execution_order());
    } else if (json_parser.async_dump_enabled()) {
      // async layout: <path>/device_<id>/graphs
      std::string root_dir = json_parser.path() + "/device_" + std::to_string(device_id);
      std::string target_dir = root_dir + "/graphs";
      std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
      DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
      DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
      DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
                        graph->execution_order());
    }
  }
#endif
}
  1196. void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  1197. MS_LOG(INFO) << "Start!";
  1198. MS_EXCEPTION_IF_NULL(kernel_graph);
  1199. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1200. MS_EXCEPTION_IF_NULL(runtime_instance);
  1201. (void)runtime_instance->LoadData(kernel_graph.get());
  1202. MS_LOG(INFO) << "Finish!";
  1203. }
  1204. void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph,
  1205. std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  1206. MS_EXCEPTION_IF_NULL(graph);
  1207. MS_EXCEPTION_IF_NULL(summary);
  1208. // if final graph have no child graph
  1209. auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
  1210. if (graph_order_iter == graph_execute_orders_.end()) {
  1211. SessionBasic::SetSummaryNodes(graph);
  1212. auto summary_nodes = graph->summary_nodes();
  1213. summary->insert(summary_nodes.begin(), summary_nodes.end());
  1214. return;
  1215. }
  1216. // for every child graph, find summary nodes
  1217. auto graph_order = GetGraphOrder(graph->graph_id());
  1218. for (size_t i = 0; i < graph_order.size(); i++) {
  1219. auto child_graph = GetGraph(graph_order[i]);
  1220. if (child_graph == nullptr) {
  1221. continue;
  1222. }
  1223. SessionBasic::SetSummaryNodes(child_graph.get());
  1224. auto child_graph_summary = child_graph->summary_nodes();
  1225. summary->insert(child_graph_summary.begin(), child_graph_summary.end());
  1226. RecurseSetSummaryNodes(child_graph.get(), summary);
  1227. }
  1228. graph->set_summary_nodes(*summary);
  1229. }
  1230. void AscendSession::SetSummaryNodes(KernelGraph *graph) {
  1231. MS_LOG(DEBUG) << "Update summary Start";
  1232. MS_EXCEPTION_IF_NULL(graph);
  1233. auto summary_nodes = graph->summary_nodes();
  1234. std::map<std::string, std::pair<AnfNodePtr, int>> summary;
  1235. summary.insert(summary_nodes.begin(), summary_nodes.end());
  1236. RecurseSetSummaryNodes(graph, &summary);
  1237. graph->set_summary_nodes(summary);
  1238. MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
  1239. }
// Merge the execution orders of all child graphs of the final graph into one
// flat execution order on the final graph. Each merged node is stamped with its
// child graph's stream distinction label; value nodes and ref-output pairs are
// migrated to the final graph as well.
void AscendSession::MergeGraphExecOrder() {
  MS_LOG(INFO) << "Start!";
  // merge graph order
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  auto final_graph = GetGraph(final_graph_id_);
  MS_EXCEPTION_IF_NULL(final_graph);
  if (graph_order.empty()) {
    MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
    return;
  }
  // More than one child graph implies control-flow sinking, which is only
  // supported when task-sink mode is enabled.
  if (graph_order.size() > 1) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
      MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
    }
  }
  // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
  SetStreamDistinctionLabel(final_graph, graph_order[0], false);
  std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
  KernelGraphPtr last_graph = nullptr;
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto graph_id = graph_order[i];
    // BRANCH_START / BRANCH_END entries delimit control-flow regions and carry
    // no kernels to merge, so they are skipped here.
    if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
      continue;
    }
    auto child_graph = GetGraph(graph_id);
    last_graph = child_graph;
    MS_EXCEPTION_IF_NULL(child_graph);
    auto exec_order = child_graph->execution_order();
    MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
    // Append the child's nodes to the merged order, tagging each node with the
    // child graph's stream distinction label as a side effect of the transform.
    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
                         [&](CNodePtr node) -> CNodePtr {
                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
                           return node;
                         });
    // add all value nodes of child graphs to final graph
    for (auto &value_node : child_graph->graph_value_nodes()) {
      final_graph->AddValueNodeToGraph(value_node);
    }
    // copy ref map to final graph
    auto child_ref_map = child_graph->GetRefMap();
    for (auto &item : child_ref_map) {
      // NOTE(review): unlike UpdateRefOutputMap (which only warns), a duplicate
      // ref pair here is a hard error — presumably child graphs merged at this
      // stage must have disjoint ref maps; confirm before relaxing.
      if (final_graph->IsInRefOutputMap(item.first)) {
        MS_LOG(EXCEPTION) << "The ref pair is already in final graph!";
      }
      final_graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
  // set final_exec_order into final graph
  MS_EXCEPTION_IF_NULL(final_graph);  // redundant with the check above; kept as-is
  DumpGraphExeOrder(final_exec_order);
  final_graph->set_execution_order(final_exec_order);
}
  1295. const std::vector<GraphId> &AscendSession::GetGraphOrder(GraphId final_graph_id) const {
  1296. auto graph_order_iter = graph_execute_orders_.find(final_graph_id);
  1297. if (graph_order_iter == graph_execute_orders_.end()) {
  1298. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph";
  1299. }
  1300. return graph_order_iter->second;
  1301. }
  1302. const std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id) const {
  1303. auto graph_type_iter = graph_order_types_.find(final_graph_id);
  1304. if (graph_type_iter == graph_order_types_.end()) {
  1305. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_";
  1306. }
  1307. return graph_type_iter->second;
  1308. }
  1309. void AscendSession::SyncInitialTenosrToDevice() {
  1310. for (auto &item : initial_tenosrs_) {
  1311. auto to_graph_id = item.first.first;
  1312. auto input_idx = item.first.second;
  1313. auto front_tensor = item.second;
  1314. auto to_graph = GetGraph(to_graph_id);
  1315. MS_EXCEPTION_IF_NULL(to_graph);
  1316. std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  1317. if (input_idx >= graph_inputs.size()) {
  1318. MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
  1319. }
  1320. auto backend_parameter = graph_inputs[input_idx];
  1321. // sync data from host to device
  1322. MS_EXCEPTION_IF_NULL(front_tensor);
  1323. size_t tensor_size = front_tensor->data().nbytes();
  1324. auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  1325. MS_EXCEPTION_IF_NULL(addr);
  1326. if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
  1327. front_tensor->data_type(), front_tensor->data_c())) {
  1328. MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  1329. }
  1330. }
  1331. }
  1332. void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
  1333. MS_LOG(INFO) << "Start BackendCommonOptimization";
  1334. for (auto &graph : all_graphs) {
  1335. opt::BackendCommonOptimization(graph);
  1336. }
  1337. MS_LOG(INFO) << "End.";
  1338. }
  1339. void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
  1340. AscendAutoMonad auto_monad(graph);
  1341. auto_monad.GenerateExecuteOrder();
  1342. }
  1343. void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  1344. if (memo->find(graph) != memo->end()) {
  1345. return;
  1346. }
  1347. memo->insert(graph.get());
  1348. opt::AscendBackendIRFusionOptimization(graph);
  1349. graph->SetExecOrderByDefault();
  1350. auto context_ptr = MsContext::GetInstance();
  1351. MS_EXCEPTION_IF_NULL(context_ptr);
  1352. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1353. if (save_graphs) {
  1354. std::string file_name = "select_kernel_before_graph_" + std::to_string(graph->graph_id()) + ".ir";
  1355. DumpIR(file_name, graph.get());
  1356. }
  1357. for (auto &child_graph : graph->child_graph_order()) {
  1358. IrFusionPass(NOT_NULL(child_graph.lock()), memo);
  1359. }
  1360. }
  1361. void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
  1362. MS_LOG(INFO) << "Start select kernel.";
  1363. size_t raise_precision_count = 0;
  1364. size_t reduce_precision_count = 0;
  1365. std::set<KernelGraphPtr> memo;
  1366. (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
  1367. memo.clear();
  1368. auto ms_context = MsContext::GetInstance();
  1369. MS_EXCEPTION_IF_NULL(ms_context);
  1370. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  1371. if (raise_precision_count > 0) {
  1372. MS_LOG(WARNING) << "There are " << raise_precision_count
  1373. << " node/nodes used raise precision to selected the kernel!";
  1374. }
  1375. if (reduce_precision_count > 0) {
  1376. MS_LOG(WARNING) << "There are " << reduce_precision_count
  1377. << " node/nodes used reduce precision to selected the kernel!";
  1378. }
  1379. }
  1380. MS_LOG(INFO) << "Finish!";
  1381. }
  1382. void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
  1383. NotNull<std::set<KernelGraphPtr> *> const memo,
  1384. size_t *const raise_precision_count,
  1385. size_t *const reduce_precision_count) const {
  1386. if (memo->find(graph) != memo->end()) {
  1387. return;
  1388. }
  1389. memo->insert(graph.get());
  1390. MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();
  1391. for (const auto &cnode : graph->execution_order()) {
  1392. if (AnfAlgo::IsCondControlKernel(cnode)) {
  1393. std::vector<KernelGraphPtr> child_graphs;
  1394. if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) {
  1395. child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
  1396. }
  1397. for (auto &child_graph : child_graphs) {
  1398. RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count);
  1399. }
  1400. }
  1401. auto status = device::ascend::SelectKernelInfo(cnode);
  1402. if (status == device::ascend::kStatusRaisePrecision) {
  1403. (*raise_precision_count)++;
  1404. } else if (status == device::ascend::kStatusReducePrecision) {
  1405. (*reduce_precision_count)++;
  1406. }
  1407. }
  1408. auto context_ptr = MsContext::GetInstance();
  1409. MS_EXCEPTION_IF_NULL(context_ptr);
  1410. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1411. if (save_graphs) {
  1412. std::string file_name = "select_kernel_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
  1413. DumpIR(file_name, graph.get());
  1414. }
  1415. MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
  1416. }
  1417. void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
  1418. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  1419. if (memo->find(graph) != memo->end()) {
  1420. return;
  1421. }
  1422. memo->insert(graph.get());
  1423. MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();
  1424. HardwareOptimize(graph.get());
  1425. for (auto &child_graph : graph->child_graph_order()) {
  1426. HardwareOptimize(NOT_NULL(child_graph.lock()), memo);
  1427. }
  1428. MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
  1429. }
  1430. void AscendSession::LoadGraphsToDbg(NotNull<KernelGraphPtr> graph,
  1431. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  1432. if (memo->find(graph) != memo->end()) {
  1433. return;
  1434. }
  1435. memo->insert(graph.get());
  1436. MS_LOG(INFO) << "Start to do LoadGraphsToDbg in graph: " << graph->graph_id();
  1437. debugger_->LoadGraphs(graph);
  1438. MS_LOG(INFO) << "graph_sum_: " << graph_sum_;
  1439. for (auto &child_graph : graph->child_graph_order()) {
  1440. LoadGraphsToDbg(NOT_NULL(child_graph.lock()), memo);
  1441. }
  1442. MS_LOG(INFO) << "Finish doing LoadGraphsToDbg in graph: " << graph->graph_id();
  1443. }
  1444. void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
  1445. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  1446. if (memo->find(graph) != memo->end()) {
  1447. return;
  1448. }
  1449. memo->insert(graph.get());
  1450. MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();
  1451. // assign static memory for parameters
  1452. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1453. MS_EXCEPTION_IF_NULL(runtime_instance);
  1454. runtime_instance->ClearGlobalIdleMem();
  1455. runtime_instance->AssignStaticMemoryInput(graph.get().get());
  1456. runtime_instance->AssignStaticMemoryValueNode(graph.get().get());
  1457. for (auto &child_graph : graph->child_graph_order()) {
  1458. AssignStaticMemory(NOT_NULL(child_graph.lock()), memo);
  1459. }
  1460. MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
  1461. }
  1462. void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
  1463. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  1464. if (memo->find(graph) != memo->end()) {
  1465. return;
  1466. }
  1467. memo->insert(graph.get());
  1468. for (auto &child_graph : graph->child_graph_order()) {
  1469. std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
  1470. MS_EXCEPTION_IF_NULL(child_graph_ptr);
  1471. UpdateRefOutputMap(NOT_NULL(child_graph_ptr), memo);
  1472. // copy ref map to final graph
  1473. auto child_ref_map = child_graph_ptr->GetRefMap();
  1474. for (auto &item : child_ref_map) {
  1475. if (graph->IsInRefOutputMap(item.first)) {
  1476. MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second
  1477. << "> is already in " << graph->ToString();
  1478. continue;
  1479. }
  1480. graph->AddRefCorrespondPairs(item.first, item.second);
  1481. }
  1482. }
  1483. }
  1484. void AscendSession::SyncStream() {
  1485. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  1486. MS_EXCEPTION_IF_NULL(runtime_instance);
  1487. auto ret = runtime_instance->SyncStream();
  1488. if (!ret) {
  1489. MS_LOG(EXCEPTION) << "Sync stream error!";
  1490. }
  1491. }
  1492. std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
  1493. return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
  1494. }
  1495. } // namespace session
  1496. } // namespace mindspore