You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ascend_session.cc 40 kB

5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/ascend_session.h"
  17. #include <algorithm>
  18. #include <map>
  19. #include <tuple>
  20. #include <set>
  21. #include <string>
  22. #include <list>
  23. #include "base/core_ops.h"
  24. #include "ir/tensor.h"
  25. #include "ir/anf.h"
  26. #include "common/trans.h"
  27. #include "runtime/device/kernel_runtime.h"
  28. #include "runtime/device/ascend/kernel_select_ascend.h"
  29. #include "runtime/device/ascend/kernel_build_ascend.h"
  30. #include "runtime/device/ascend/ascend_kernel_runtime.h"
  31. #include "backend/optimizer/ascend/ascend_backend_optimization.h"
  32. #include "backend/optimizer/common/common_backend_optimization.h"
  33. #include "runtime/device/kernel_adjust.h"
  34. #include "runtime/device/ascend/ascend_stream_assign.h"
  35. #include "backend/session/anf_runtime_algorithm.h"
  36. #include "utils/ms_utils.h"
  37. #include "backend/optimizer/common/helper.h"
  38. #include "runtime/device/kernel_runtime_manager.h"
  39. #include "utils/config_manager.h"
  40. #include "debug/data_dump/dump_json_parser.h"
  41. #include "debug/tensor_load.h"
  42. #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
  43. #include "debug/data_dump/e2e_dump_util.h"
  44. #include "debug/anf_ir_dump.h"
  45. #include "debug/dump_proto.h"
  46. #include "toolchain/adx_datadump_server.h"
  47. namespace mindspore {
  48. namespace session {
// Sentinel meaning "no valid index".
const size_t kInvalidIndex = SIZE_MAX;
// Input index of the data operand in a Return node (operand 0 is the primitive).
constexpr size_t kReturnDataIndex = 1;
  51. namespace {
// Logs a graph's execution order: for each CNode, its index, stream
// distinction label, owning graph id, and debug string.
// A second pass formats the same data into `buf` with an optional `tag` header.
// NOTE(review): `buf` is built but never emitted or returned — presumably it
// was once written to a file or log; confirm whether this is intentional.
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
  MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
  MS_LOG(INFO) << "[index][stream_label][graph_id][node string]";
  int i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "[ " << i << "]"
                 << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]"
                 << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]"
                 << "[" << cnode->DebugString() << "]";
    i++;
  }
  // Second pass: assemble a human-readable dump of the same execution order.
  std::stringstream buf;
  buf << "================== execution order ==================\n";
  if (!tag.empty()) {
    buf << tag << "\n";
  }
  buf << "execution_order size: " << execution_order.size() << "\n";
  i = 0;
  for (auto &cnode : execution_order) {
    MS_EXCEPTION_IF_NULL(cnode);
    buf << i << ":\n";
    buf << "\t" << cnode->DebugString() << "\n";
    buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n";
    buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n";
    i++;
  }
  buf << "================== execution order ==================\n";
}
  81. void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
  82. MS_EXCEPTION_IF_NULL(graph);
  83. if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
  84. graph->set_stream_distinction_label(label);
  85. }
  86. }
  87. std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
  88. std::vector<CNodePtr> cnodes = {};
  89. size_t i = 0;
  90. for (const auto &anf : anf_nodes) {
  91. MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString();
  92. MS_EXCEPTION_IF_NULL(anf);
  93. if (anf->isa<CNode>()) {
  94. cnodes.push_back(anf->cast<CNodePtr>());
  95. }
  96. }
  97. return cnodes;
  98. }
// Wraps the root graph's output in a MakeTuple node so that downstream code
// always sees a tuple-shaped output.
// Does nothing when the Return node carries no data operand.
void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
  auto return_node = root_graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  // Return has no data input (size <= 1): nothing to wrap.
  if (return_node->size() <= kReturnDataIndex) {
    return;
  }
  auto make_tuple = root_graph->NewCNode(
    {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
  root_graph->set_output(make_tuple);
}
  109. } // namespace
  110. GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  111. MS_LOG(INFO) << "Start";
  112. // construct graph, if successfully, graph_sum_ + 1
  113. auto graph = ConstructKernelGraph(lst, outputs);
  114. auto graph_id = graph->graph_id();
  115. MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  116. return graph_id;
  117. }
// Compiles a whole func_graph (possibly with control flow and child graphs)
// into kernel graphs and prepares them for execution on Ascend.
// Returns the root graph's id. The pass order below is significant.
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  MS_LOG(INFO) << "Start";
  std::vector<KernelGraphPtr> all_graphs;
  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  // Update graph dynamic-shape attributes
  UpdateGraphDynamicShapeAttr(NOT_NULL(root_graph));
  root_graph->UpdateGraphDynamicAttr();
  BackendOptimization(all_graphs);
  // An empty graph does not enter the backend: mark non-executable and return.
  if (root_graph->execution_order().empty()) {
    MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
    InsertMakeTupleForOutput(NOT_NULL(root_graph));
    root_graph->set_executable(false);
    InitRuntimeResource();
    return root_graph->graph_id();
  }
  // Create parameters for multiple branches; `memo` tracks visited graphs
  // and is cleared between the recursive passes below.
  std::set<KernelGraphPtr> memo;
  CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  // Insert goto labels and label_sets
  LinkChildGraphs(NOT_NULL(root_graph));
  // Resource initialization
  InitRuntimeResource();
  IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  SelectKernel(NOT_NULL(root_graph));
  memo.clear();
  HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  // Add a make_tuple wrapping the graph output
  InsertMakeTupleForOutput(NOT_NULL(root_graph));
  // Root graph validation, including generating the execution order and so on.
  RootGraphExecutorValidate(NOT_NULL(root_graph));
  // Adjust kernels
  AdjustKernel(root_graph);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  // Assign parameter keys.
  AssignParamKey(root_graph);
#endif
  // Assign streams
  AssignStream(NOT_NULL(root_graph));
  // Insert profiling points
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
  // Build kernels
  BuildKernel(root_graph);
  // Partial-memory debugging needs a hook before execution resources are fixed.
  if (debugger_ && debugger_->partial_memory()) {
    debugger_->PreExecute(root_graph);
  }
  SetSummaryNodes(root_graph.get());
  // Alloc memory for child graphs' inputs
  AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
  memo.clear();
  // Alloc memory for the root graph's inputs and nodes' outputs / workspace
  MemoryAlloc(root_graph.get());
  // Generate tasks and load them onto the device
  Load(root_graph);
  DumpAllGraphs(all_graphs);
  // Return the root graph id to the backend
  auto graph_id = root_graph->graph_id();
  return graph_id;
}
  182. void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
  183. MS_EXCEPTION_IF_NULL(kernel_graph);
  184. auto graph_order = GetGraphOrder(kernel_graph->graph_id());
  185. for (auto graph_id : graph_order) {
  186. auto child_graph = GetGraph(graph_id);
  187. if (child_graph == nullptr) {
  188. continue;
  189. }
  190. if (child_graph->summary_node_exist()) {
  191. kernel_graph->set_summary_node_exist(true);
  192. return;
  193. }
  194. }
  195. kernel_graph->set_summary_node_exist(false);
  196. }
// Builds a constructed graph end-to-end: compiles child graphs (or the
// single graph), adjusts kernels, assigns streams, builds kernels, and —
// unless running in precompile-only mode — allocates memory and loads tasks.
void AscendSession::BuildGraphImpl(GraphId graph_id) {
  MS_LOG(INFO) << "Start";
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  // Resource initialization
  InitRuntimeResource();
  // Multiple-graph handling: the final graph owns an ordered list of children.
  if (graph_id == final_graph_id_) {
    if (!graph->executable()) {
      return;
    }
    // Insert assigns into child graphs
    InsertAllAssigns();
    SetFinalGraphSummaryFlag(graph);
    // Optimize child graphs; BRANCH_START/BRANCH_END entries are markers, not graphs.
    auto graph_order = GetGraphOrder(final_graph_id_);
    auto &graph_type = GetGraphOrderType(final_graph_id_);
    for (size_t i = 0; i < graph_order.size(); i++) {
      if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) {
        auto child_graph = GetGraph(graph_order[i]);
        CompileChildGraph(child_graph);
      }
    }
    SetSummaryNodes(graph.get());
    // Merge the child graphs' execution orders into the final graph
    MergeGraphExecOrder();
  } else {
    auto single_graph = GetGraph(graph_id);
    MS_EXCEPTION_IF_NULL(single_graph);
    CompileChildGraph(single_graph);
    // A single graph uses its own id as stream distinction label.
    single_graph->set_stream_distinction_label(graph_id);
    single_graph->UpdateExecuteKernelStreamLabel();
  }
  // Adjust execution order because of merged child graphs and other special operations
  AdjustKernel(graph);
  // Assign streams for control sink, hccl and so on
  AssignStream(NOT_NULL(graph));
  device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
  // Build a kernel for every cnode
  BuildKernel(graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (debugger_ && debugger_->partial_memory()) {
    debugger_->PreExecute(graph);
  }
  if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "Precompile only, stop in build kernel step";
  } else {
    // Alloc memory, including static memory and dynamic memory
    MemoryAlloc(graph.get());
    // Generate and load task info to device if it is sink mode
    Load(graph);
  }
  // Sync the initial const tensors to the device
  // (note: "Tenosr" typo is part of the interface name declared elsewhere).
  SyncInitialTenosrToDevice();
  DumpAllGraphs({graph});
  MS_LOG(INFO) << "End";
}
// Compiles one child graph: IR fusion, basic-op and graph-kernel fusion,
// kernel selection, hardware optimization, then static memory assignment
// for the graph's inputs and value nodes.
// Dumps IR before/after kernel selection when MS_CTX_SAVE_GRAPHS_FLAG is set.
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
  MS_EXCEPTION_IF_NULL(child_graph);
  MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
  opt::AscendBackendIRFusionOptimization(child_graph);
  opt::AscendBackendFuseBasicOpt(child_graph, true);
  opt::AscendBackendGraphKernelOpt(child_graph, true);
  child_graph->SetExecOrderByDefault();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    std::string file_name = "select_kernel_before_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
    DumpIR(file_name, child_graph);
  }
  // Select a kernel build info for every node
  SelectKernel(*child_graph);
  if (save_graphs) {
    std::string file_name = "select_kernel_after_graph_" + std::to_string(child_graph->graph_id()) + ".ir";
    DumpIR(file_name, child_graph);
  }
  // Hardware-dependent graph optimization
  HardwareOptimize(child_graph);
  // Assign static memory for parameters and value nodes
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignStaticMemoryInput(child_graph.get());
  runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
}
// Runs a compiled graph: loads user inputs onto the device, executes the
// loaded tasks, then handles summary nodes and debugger post-processing.
// `outputs` is part of the interface; this path does not write it directly.
void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *const outputs) {
  MS_LOG(INFO) << "Start";
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Graphs with no executable content (no child graph has an ANF output) are skipped.
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has anf output";
    return;
  }
  // Load input data from user input
  LoadInputData(kernel_graph, inputs);
  if (debugger_) {
    debugger_->PreExecute(kernel_graph);
  }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  // Initialize parameter server
  InitPSParamAndOptim(kernel_graph, inputs);
#endif
  {
    // Run tasks on the device (task-sink mode)
    Execute(kernel_graph, true);
  }
  // Summary
  Summary(kernel_graph.get());
  // Load tensors from device for the debugger
  if (debugger_ && debugger_->debugger_enabled()) {
    LoadTensor(kernel_graph);
  }
  // Debugger post-execution processing
  if (debugger_) {
    debugger_->PostExecute();
  }
  MS_LOG(INFO) << "Finish!";
}
// Hardware optimization for the single-op (PyNative) path:
// data layout conversion followed by mixed-precision optimization.
void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start";
  // Data layout optimization
  opt::RunOpAscendDataLayout(kernel_graph);
  // Mixed precision optimization
  opt::AscendMixPrecision(kernel_graph);
  MS_LOG(INFO) << "Finish";
}
  327. bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  328. return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
  329. }
// Builds (and caches) a single-op graph for the PyNative path: constructs a
// one-op kernel graph, runs IR fusion, selects kernels, optimizes, and builds
// kernels. A cache hit returns immediately.
void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                const std::vector<tensor::TensorPtr> &input_tensors,
                                const std::vector<int> &tensors_mask) {
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
  if (GraphCacheExist(graph_info)) {
    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !";
    return;
  }
  // Construct a graph containing the single op
  auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
  MS_EXCEPTION_IF_NULL(graph);
  opt::RunOpAscendBackendIRFusionOptimization(graph);
  // Kernel selection
  SelectKernel(*graph);
  // Hardware optimization
  RunOpHardwareOptimize(graph);
  // Init runtime resource
  InitRuntimeResource();
  // Build kernels
  RunOpAdjustKernel(graph);
  BuildKernel(graph);
  run_op_graphs_[graph_info] = graph;
  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
}
  354. void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  355. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  356. auto graph = run_op_graphs_[graph_info];
  357. MS_EXCEPTION_IF_NULL(graph);
  358. MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
  359. // malloc mem
  360. RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
  361. // load input data to device
  362. LoadInputData(graph, input_tensors);
  363. // run op
  364. Execute(graph, false);
  365. // get output
  366. UpdateOutputs(graph, outputs, input_tensors);
  367. RunOpMemoryClear(graph.get());
  368. MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
  369. }
// Compile-graph step: selects a kernel build info for every node in the
// graph's execution order, counting how many selections had to raise or
// reduce precision; warns about those counts in graph mode.
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  size_t raise_precision_count = 0;
  size_t reduce_precision_count = 0;
  for (const auto &cnode : kernel_graph.execution_order()) {
    auto status = device::ascend::SelectKernelInfo(cnode);
    if (status == device::ascend::kStatusRaisePrecision) {
      raise_precision_count++;
    } else if (status == device::ascend::kStatusReducePrecision) {
      reduce_precision_count++;
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Only graph mode warns: precision changes are routine in PyNative mode.
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (raise_precision_count > 0) {
      MS_LOG(WARNING) << "There has " << raise_precision_count
                      << " node/nodes used raise precision to selected the kernel!";
    }
    if (reduce_precision_count > 0) {
      MS_LOG(WARNING) << "There has " << reduce_precision_count
                      << " node/nodes used reduce precision to selected the kernel!";
    }
  }
  MS_LOG(INFO) << "Finish!";
}
  398. void DumpInit() {
  399. auto &json_parser = DumpJsonParser::GetInstance();
  400. json_parser.Parse();
  401. if (json_parser.async_dump_enabled()) {
  402. if (AdxDataDumpServerInit() != 0) {
  403. MS_LOG(EXCEPTION) << "Adx data dump server init failed";
  404. }
  405. }
  406. }
// Initializes the Ascend kernel runtime for this device and then the dump
// facility. Throws when runtime initialization fails.
void AscendSession::InitRuntimeResource() {
  MS_LOG(INFO) << "Start!";
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  if (!runtime_instance->Init()) {
    MS_LOG(EXCEPTION) << "Kernel runtime init error.";
  }
  DumpInit();
  MS_LOG(INFO) << "Finish!";
}
  417. void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  418. MS_LOG(INFO) << "HardwareOptimize start!";
  419. opt::AscendBackendOptimization(kernel_graph);
  420. opt::AscendGraphKernelCommonProcess(kernel_graph);
  421. opt::AscendBackendFuseBasicOpt(kernel_graph, false);
  422. opt::AscendBackendAddAtomicClean(kernel_graph);
  423. MS_EXCEPTION_IF_NULL(kernel_graph);
  424. kernel_graph->SetExecOrderByDefault();
  425. MS_LOG(INFO) << "HardwareOptimize Finish!";
  426. }
// Adjusts the graph for execution: hides nop nodes, pre-builds kernels so the
// atomic-clean info can be read from the generated json, runs the kernel
// build preprocess, and inserts the switch loop. Optionally dumps IR.
void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  opt::HideNopNode(kernel_graph.get());
  // Insert ClearZero ops:
  // the kernel build below prepares the json from which atomic info is read.
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    DumpIR("after_adjust_kernel.ir", kernel_graph);
  }
  MS_LOG(INFO) << "Finish!";
}
// Single-op variant of AdjustKernel: hides nop nodes, pre-builds kernels
// (so atomic info can be read from the generated json) and runs the kernel
// build preprocess. No switch loop is inserted for single ops.
void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  opt::HideNopNode(kernel_graph.get());
  // Insert ClearZero ops; prepare for the next step, reading atomic info from json
  BuildKernel(kernel_graph);
  device::ascend::KernelBuildPreprocess(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}
// Delegates stream assignment for the graph to the Ascend stream assigner.
void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
  457. void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  458. MS_LOG(INFO) << "Start!";
  459. struct timeval start_time, end_time;
  460. (void)gettimeofday(&start_time, nullptr);
  461. auto ret = device::ascend::KernelBuild(kernel_graph.get());
  462. if (!ret) {
  463. MS_LOG(EXCEPTION) << "Kernel build error.";
  464. }
  465. (void)gettimeofday(&end_time, nullptr);
  466. const uint64_t kUSecondInSecond = 1000000;
  467. uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  468. cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  469. MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost;
  470. MS_LOG(INFO) << "Finish!";
  471. }
// Allocates device memory for the whole graph (inputs, outputs, workspace).
// Nop nodes are removed first so no memory is assigned to them.
void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->AssignMemory(kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
// Allocates device memory for a single-op graph, reusing `pre_output_value`
// and `input_tensors` where the runtime supports it.
void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value,
                                     const std::vector<tensor::TensorPtr> &input_tensors,
                                     KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start memory alloc!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  opt::RemoveNopNode(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
  MS_LOG(INFO) << "Finish!";
}
// Releases the device memory that RunOpMemoryAlloc assigned to a single-op graph.
void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->RunOpClearMemory(kernel_graph);
}
// Loads the graph's tasks onto the device (task-sink mode is taken from the
// MS context), after stepping the control inputs. Throws when loading fails.
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  // Best-effort: the return value of StepLoadCtrlInputs is deliberately ignored.
  (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "Load task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
// Runs the graph on the device. When `is_task` is true, task-sink mode is
// taken from the MS context; single-op execution passes false.
// Throws when the run fails.
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
  MS_LOG(INFO) << "Start!";
  bool is_task_sink = false;
  if (is_task) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  }
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
  // Dump runs before the failure check, so dump data is captured even when
  // the run failed and the exception below is about to be raised.
  Dump(kernel_graph);
  if (!ret_ok) {
    MS_LOG(EXCEPTION) << "run task error!";
  }
  MS_LOG(INFO) << "Finish!";
}
// Dumps the graph's data via the end-to-end dump utility for this device.
void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  E2eDumpUtil::DumpData(kernel_graph.get(), device_id_);
  MS_LOG(INFO) << "Finish!";
}
// Dumps IR text and proto files for every graph when IR dumping is compiled
// in (ENABLE_DUMP_IR) and MS_CTX_SAVE_GRAPHS_FLAG is set; otherwise a no-op.
void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (!save_graphs) {
    return;
  }
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true);
    DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
  }
#endif
}
// Loads tensor data from the device into the debugger's tensor loader
// (debugger builds only). Clears previously-held tensors, bumps the
// iteration counter, then pulls fresh data from the runtime.
// NOTE(review): `debugger_` is dereferenced without a null check here; the
// visible callers gate on `debugger_` first — confirm all callers do.
void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
#ifdef ENABLE_DEBUGGER
  if (debugger_->DebuggerBackendEnabled()) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
    MS_EXCEPTION_IF_NULL(runtime_instance);
    DebugServices *debug_services = debugger_->debug_services();
    MS_EXCEPTION_IF_NULL(debug_services);
    TensorLoader *tensor_loader = debug_services->tensor_loader();
    MS_EXCEPTION_IF_NULL(tensor_loader);
    // TensorData held from the previous iteration is freed here.
    tensor_loader->EmptyTensor();
    uint32_t iter_num = tensor_loader->GetIterNum();
    tensor_loader->set_iter_num(++iter_num);
    (void)runtime_instance->LoadData(kernel_graph.get());
    tensor_loader->EmptyPrevTensor();
  }
#endif
  MS_LOG(INFO) << "Finish!";
}
// Recursively collects summary nodes from `graph` and all of its child
// graphs into `*summary`, then stores the accumulated map back on `graph`.
// A graph with no recorded execution order is a leaf: its own summary nodes
// are gathered directly.
void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph,
                                           std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(summary);
  // Leaf case: the final graph has no child graphs recorded.
  auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
  if (graph_order_iter == graph_execute_orders_.end()) {
    SessionBasic::SetSummaryNodes(graph);
    auto summary_nodes = graph->summary_nodes();
    summary->insert(summary_nodes.begin(), summary_nodes.end());
    return;
  }
  // Recursive case: gather summary nodes from every child graph.
  auto graph_order = GetGraphOrder(graph->graph_id());
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto child_graph = GetGraph(graph_order[i]);
    if (child_graph == nullptr) {
      continue;  // branch markers have no backing graph
    }
    SessionBasic::SetSummaryNodes(child_graph.get());
    auto child_graph_summary = child_graph->summary_nodes();
    summary->insert(child_graph_summary.begin(), child_graph_summary.end());
    RecurseSetSummaryNodes(child_graph.get(), summary);
  }
  graph->set_summary_nodes(*summary);
}
  598. void AscendSession::SetSummaryNodes(KernelGraph *graph) {
  599. MS_LOG(DEBUG) << "Update summary Start";
  600. MS_EXCEPTION_IF_NULL(graph);
  601. auto summary_nodes = graph->summary_nodes();
  602. std::map<std::string, std::pair<AnfNodePtr, int>> summary;
  603. summary.insert(summary_nodes.begin(), summary_nodes.end());
  604. RecurseSetSummaryNodes(graph, &summary);
  605. graph->set_summary_nodes(summary);
  606. MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
  607. }
// Materialize all queued assign requests (`assigns_`: tuples of front-end
// node, target graph id, and input index): each one becomes an Assign op that
// copies the backend value of the front-end node into the target graph's
// input parameter. Duplicate (source, destination) pairs are inserted once.
void AscendSession::InsertAllAssigns() {
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
  // Phase 1: resolve each request to a (front node, backend parameter) pair.
  for (auto assign : assigns_) {
    auto front_anf = std::get<0>(assign);
    auto to_graph_id = std::get<1>(assign);
    auto input_idx = std::get<2>(assign);
    auto to_graph = GetGraph(to_graph_id);
    MS_EXCEPTION_IF_NULL(to_graph);
    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
    if (input_idx >= graph_inputs.size()) {
      MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
    }
    auto backend_parameter = graph_inputs[input_idx];
    assigns.emplace_back(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
  }
  // erase the repeat assign
  // Phase 2: insert each unique pair into the graph that owns the source node.
  std::set<std::pair<AnfNodePtr, AnfNodePtr>> inserted_nodes;
  for (auto &assign : assigns) {
    auto front_anf = assign.first;
    auto backend_parameter = assign.second;
    auto from_graph_id = GetGraphIdByNode(front_anf);
    auto from_graph = GetGraph(from_graph_id);
    MS_EXCEPTION_IF_NULL(from_graph);
    auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
    if (inserted_nodes.find(assign) == inserted_nodes.end()) {
      InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
      (void)inserted_nodes.insert(assign);
    }
  }
}
  638. GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  639. for (const auto &graph_item : graphs_) {
  640. auto graph = graph_item.second;
  641. MS_EXCEPTION_IF_NULL(graph);
  642. // if front_anf is a parameter,the backend parameter may have two
  643. if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
  644. return graph_item.first;
  645. }
  646. }
  647. MS_EXCEPTION_IF_NULL(front_anf);
  648. MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
  649. return kInvalidGraphId;
  650. }
// Flatten the execution orders of all child graphs of the final graph into a
// single execution order on the final graph, also merging each child's value
// nodes and ref-output pairs. BRANCH_START/BRANCH_END entries in the order are
// control markers and contribute no kernels.
void AscendSession::MergeGraphExecOrder() {
  MS_LOG(INFO) << "Start!";
  // merge graph order
  auto &graph_order = GetGraphOrder(final_graph_id_);
  auto &graph_type = GetGraphOrderType(final_graph_id_);
  auto final_graph = GetGraph(final_graph_id_);
  MS_EXCEPTION_IF_NULL(final_graph);
  if (graph_order.empty()) {
    MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
    return;
  }
  // More than one graph means control flow, which requires task-sink mode.
  if (graph_order.size() > 1) {
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
      MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
    }
  }
  // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
  SetStreamDistinctionLabel(final_graph, graph_order[0], false);
  std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
  // NOTE(review): last_graph is written each iteration but never read in this
  // function — looks like a leftover; confirm before removing.
  KernelGraphPtr last_graph = nullptr;
  for (size_t i = 0; i < graph_order.size(); i++) {
    auto graph_id = graph_order[i];
    if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
      continue;
    }
    auto child_graph = GetGraph(graph_id);
    last_graph = child_graph;
    MS_EXCEPTION_IF_NULL(child_graph);
    auto exec_order = child_graph->execution_order();
    MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
    // Stamp every merged node with its child graph's stream label while
    // appending it to the final order.
    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
                         [&](CNodePtr node) -> CNodePtr {
                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
                           return node;
                         });
    // add all value nodes of child graphs to final graph
    for (auto &value_node : child_graph->graph_value_nodes()) {
      final_graph->AddValueNodeToGraph(value_node);
    }
    // copy ref map to final graph
    auto child_ref_map = child_graph->GetRefMap();
    for (auto &item : child_ref_map) {
      if (final_graph->IsInRefOutputMap(item.first)) {
        MS_LOG(EXCEPTION) << "The ref pair is already in final graph!";
      }
      final_graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
  // set final_exec_order into final graph
  MS_EXCEPTION_IF_NULL(final_graph);
  DumpGraphExeOrder(final_exec_order);
  final_graph->set_execution_order(final_exec_order);
}
  706. void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  707. MS_EXCEPTION_IF_NULL(from);
  708. MS_EXCEPTION_IF_NULL(to);
  709. if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
  710. AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
  711. return;
  712. }
  713. if (from.get() == to.get()) {
  714. return;
  715. }
  716. MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to "
  717. << to->DebugString();
  718. auto graph = graphs_[graph_id];
  719. MS_EXCEPTION_IF_NULL(graph);
  720. // config inputs of assign node
  721. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  722. // generate a new cnode
  723. auto assign_node = graph->NewCNode(inputs);
  724. MS_EXCEPTION_IF_NULL(assign_node);
  725. assign_node->set_abstract(to->abstract());
  726. // append the assign at the end of from graph
  727. AscendControlParser::InsertDependToGraph(NOT_NULL(graph), NOT_NULL(assign_node));
  728. }
  729. const std::vector<GraphId> &AscendSession::GetGraphOrder(GraphId final_graph_id) const {
  730. auto graph_order_iter = graph_execute_orders_.find(final_graph_id);
  731. if (graph_order_iter == graph_execute_orders_.end()) {
  732. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph";
  733. }
  734. return graph_order_iter->second;
  735. }
  736. const std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id) const {
  737. auto graph_type_iter = graph_order_types_.find(final_graph_id);
  738. if (graph_type_iter == graph_order_types_.end()) {
  739. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_";
  740. }
  741. return graph_type_iter->second;
  742. }
  743. void AscendSession::SyncInitialTenosrToDevice() {
  744. for (auto &item : initial_tenosrs_) {
  745. auto to_graph_id = item.first.first;
  746. auto input_idx = item.first.second;
  747. auto front_tensor = item.second;
  748. auto to_graph = GetGraph(to_graph_id);
  749. MS_EXCEPTION_IF_NULL(to_graph);
  750. std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  751. if (input_idx >= graph_inputs.size()) {
  752. MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
  753. }
  754. auto backend_parameter = graph_inputs[input_idx];
  755. // sync data from host to device
  756. MS_EXCEPTION_IF_NULL(front_tensor);
  757. size_t tensor_size = front_tensor->data().nbytes();
  758. auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  759. MS_EXCEPTION_IF_NULL(addr);
  760. if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
  761. front_tensor->data_type(), front_tensor->data_c())) {
  762. MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  763. }
  764. }
  765. }
  766. void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
  767. MS_LOG(INFO) << "Start BackendCommonOptimization";
  768. for (auto &graph : all_graphs) {
  769. opt::BackendCommonOptimization(graph);
  770. }
  771. MS_LOG(INFO) << "End.";
  772. }
// Delegate linking of this graph with its child graphs to the control parser.
void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
// Validate the root graph's executable form via the control parser.
void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
  AscendControlParser::ExecutorValidate(graph);
}
// For every Call/Switch node in `graph`, create a fresh parameter that will
// hold the output of whichever branch runs, insert assigns in each child graph
// that write the branch output into that parameter, and replace uses of the
// call node with Depend(parameter, call) so consumers read the parameter.
// Children are processed first; `memo` guards against revisiting (e.g. while
// loops) and the graph is removed from `memo` again on exit.
void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  if (memo->find(graph.get()) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  graph->UpdateChildGraphOrder();
  // Handle child graphs before rewriting this graph.
  for (auto &child_graph : graph->child_graph_order()) {
    CreateMultiBranchOutput(NOT_NULL(child_graph.lock()), memo);
  }
  std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
  auto node_list = GetCNodes(TopoSort(graph->get_return()));
  for (auto &node : node_list) {
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
      // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
      auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
      MS_EXCEPTION_IF_NULL(graph->MutableInputs());
      graph->AddChildGraphResult(output_param);
      // Depend ties the parameter's value to the call's execution order.
      std::vector<AnfNodePtr> depend_inputs = {
        graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), output_param, node};
      auto depend = graph->NewCNode(depend_inputs);
      depend->set_abstract(output_param->abstract());
      need_replace_list.emplace(node, depend);
      MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString()
                   << ", depend node is " << depend->DebugString();
      // insert assign in order to transfer child graph output to parameter
      auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node);
      for (auto &child_graph : child_graphs) {
        MS_EXCEPTION_IF_NULL(child_graph);
        // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert
        // assign from condition to true graph
        if (memo->find(child_graph) != memo->end()) {
          continue;
        }
        if (child_graph->get_output_null()) {
          continue;
        }
        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr,
                                                         NOT_NULL(child_graph->output()), NOT_NULL(output_param));
      }
    }
  }
  // searching for nodes' input to replace call by depend(parameter, call)
  for (auto &node : node_list) {
    for (size_t i = 0; i < node->size(); ++i) {
      auto input = node->input(i);
      auto iter = need_replace_list.find(input);
      if (iter != need_replace_list.end()) {
        node->set_input(i, iter->second);
      }
    }
  }
  memo->erase(graph.get());
}
  830. void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  831. if (memo->find(graph) != memo->end()) {
  832. return;
  833. }
  834. memo->insert(graph.get());
  835. opt::AscendBackendIRFusionOptimization(graph);
  836. opt::AscendBackendFuseBasicOpt(graph, true);
  837. opt::AscendBackendGraphKernelOpt(graph, true);
  838. graph->SetExecOrderByDefault();
  839. auto context_ptr = MsContext::GetInstance();
  840. MS_EXCEPTION_IF_NULL(context_ptr);
  841. bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  842. if (save_graphs) {
  843. std::string file_name = "select_kernel_before_graph_" + std::to_string(graph->graph_id()) + ".ir";
  844. DumpIR(file_name, graph.get());
  845. }
  846. for (auto &child_graph : graph->child_graph_order()) {
  847. IrFusionPass(NOT_NULL(child_graph.lock()), memo);
  848. }
  849. }
  850. void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
  851. MS_LOG(INFO) << "Start select kernel.";
  852. size_t raise_precision_count = 0;
  853. size_t reduce_precision_count = 0;
  854. std::set<KernelGraphPtr> memo;
  855. (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count);
  856. memo.clear();
  857. auto ms_context = MsContext::GetInstance();
  858. MS_EXCEPTION_IF_NULL(ms_context);
  859. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  860. if (raise_precision_count > 0) {
  861. MS_LOG(WARNING) << "There are " << raise_precision_count
  862. << " node/nodes used raise precision to selected the kernel!";
  863. }
  864. if (reduce_precision_count > 0) {
  865. MS_LOG(WARNING) << "There are " << reduce_precision_count
  866. << " node/nodes used reduce precision to selected the kernel!";
  867. }
  868. }
  869. MS_LOG(INFO) << "Finish!";
  870. }
// Select a kernel implementation for every node in `graph`'s execution order.
// For condition-control nodes the attached child graphs are processed first,
// so their kernels are chosen before the control node itself. The two
// counters accumulate how many nodes needed a precision raise/reduction.
void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
                                            NotNull<std::set<KernelGraphPtr> *> const memo,
                                            size_t *const raise_precision_count,
                                            size_t *const reduce_precision_count) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id();
  for (const auto &cnode : graph->execution_order()) {
    if (AnfAlgo::IsCondControlKernel(cnode)) {
      // Recurse into the child graphs attached to this control node (if any).
      std::vector<KernelGraphPtr> child_graphs;
      if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) {
        child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
      }
      for (auto &child_graph : child_graphs) {
        RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count);
      }
    }
    auto status = device::ascend::SelectKernelInfo(cnode);
    if (status == device::ascend::kStatusRaisePrecision) {
      (*raise_precision_count)++;
    } else if (status == device::ascend::kStatusReducePrecision) {
      (*reduce_precision_count)++;
    }
    MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    std::string file_name = "select_kernel_after_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph.get());
  }
  MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id();
}
  907. void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph,
  908. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  909. if (memo->find(graph) != memo->end()) {
  910. return;
  911. }
  912. memo->insert(graph.get());
  913. MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id();
  914. HardwareOptimize(graph.get());
  915. for (auto &child_graph : graph->child_graph_order()) {
  916. HardwareOptimize(NOT_NULL(child_graph.lock()), memo);
  917. }
  918. MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id();
  919. }
  920. void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph,
  921. NotNull<std::set<KernelGraphPtr> *> const memo) const {
  922. if (memo->find(graph) != memo->end()) {
  923. return;
  924. }
  925. memo->insert(graph.get());
  926. MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id();
  927. // assign static memory for parameters
  928. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  929. MS_EXCEPTION_IF_NULL(runtime_instance);
  930. runtime_instance->ClearGlobalIdleMem();
  931. runtime_instance->AssignStaticMemoryInput(graph.get().get());
  932. runtime_instance->AssignStaticMemoryValueNode(graph.get().get());
  933. for (auto &child_graph : graph->child_graph_order()) {
  934. AssignStaticMemory(NOT_NULL(child_graph.lock()), memo);
  935. }
  936. MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id();
  937. }
// Recursively merge the ref-output maps of all child graphs into `graph`.
// Pairs already present in the parent are kept; a duplicate from a child only
// logs a warning. `memo` guards against revisiting shared graphs.
void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
                                       NotNull<std::set<KernelGraphPtr> *> const memo) const {
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph.get());
  for (auto &child_graph : graph->child_graph_order()) {
    std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
    MS_EXCEPTION_IF_NULL(child_graph_ptr);
    // Merge grandchildren into the child first so its map is complete here.
    UpdateRefOutputMap(NOT_NULL(child_graph_ptr), memo);
    // copy ref map to final graph
    auto child_ref_map = child_graph_ptr->GetRefMap();
    for (auto &item : child_ref_map) {
      if (graph->IsInRefOutputMap(item.first)) {
        MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second
                        << "> is already in " << graph->ToString();
        continue;
      }
      graph->AddRefCorrespondPairs(item.first, item.second);
    }
  }
}
// Run inference over `func_graph` with the given input tensors, then compile
// it through the single-argument overload; returns the compiled graph id.
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
  RunInfer(func_graph, inputs);
  return CompileGraphImpl(func_graph);
}
  964. } // namespace session
  965. } // namespace mindspore