You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_session.cc 69 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/ascend_session.h"
  17. #include <algorithm>
  18. #include <map>
  19. #include <tuple>
  20. #include <set>
  21. #include <list>
  22. #include "operator/ops.h"
  23. #include "ir/tensor.h"
  24. #include "ir/anf.h"
  25. #include "common/trans.h"
  26. #include "device/kernel_runtime.h"
  27. #include "device/ascend/kernel_select_ascend.h"
  28. #include "device/ascend/kernel_build_ascend.h"
  29. #include "device/ascend/ascend_kernel_runtime.h"
  30. #include "device/ascend/ascend_device_address.h"
  31. #include "pre_activate/ascend/ascend_backend_optimization.h"
  32. #include "device/kernel_adjust.h"
  33. #include "device/ascend/ascend_stream_assign.h"
  34. #include "device/ascend/ascend_label_assign.h"
  35. #include "predict/predict.h"
  36. #include "session/anf_runtime_algorithm.h"
  37. #include "ir/scalar.h"
  38. #include "debug/anf_ir_dump.h"
  39. #include "debug/anf_ir_utils.h"
  40. #include "common/utils.h"
  41. #include "pre_activate/common/helper.h"
  42. #include "device/kernel_runtime_manager.h"
  43. #include "kernel/tbe/tbe_python_funcs.h"
  44. #include "utils/config_manager.h"
  45. #include "utils/base_ref_extends.h"
  46. namespace mindspore {
  47. namespace session {
  // Sentinel index meaning "no valid index": the largest value a size_t can hold.
  constexpr size_t kInvalidIndex = SIZE_MAX;
  49. namespace {
  50. void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order) {
  51. MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
  52. MS_LOG(INFO) << "[index][stream_label][graph_id][node string]";
  53. int i = 0;
  54. for (auto &cnode : execution_order) {
  55. MS_EXCEPTION_IF_NULL(cnode);
  56. MS_LOG(INFO) << "[ " << i << "]"
  57. << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]"
  58. << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]"
  59. << "[" << cnode->DebugString() << "]";
  60. i++;
  61. }
  62. }
  63. void DumpGraphInputArgs(const VectorRef &args) {
  64. MS_LOG(INFO) << "Args size[%lu]" << args.size();
  65. for (size_t i = 0; i < args.size(); i++) {
  66. if (utils::isa<AnfNodePtr>(args[i])) {
  67. auto anf = utils::cast<AnfNodePtr>(args[i]);
  68. MS_EXCEPTION_IF_NULL(anf);
  69. MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString();
  70. } else if (utils::isa<ValuePtr>(args[i])) {
  71. auto value = utils::cast<ValuePtr>(args[i]);
  72. MS_EXCEPTION_IF_NULL(value);
  73. MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString();
  74. } else {
  75. MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString();
  76. }
  77. }
  78. }
  79. void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
  80. MS_EXCEPTION_IF_NULL(graph);
  81. if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
  82. graph->set_stream_distinction_label(label);
  83. }
  84. }
  85. std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
  86. MS_EXCEPTION_IF_NULL(graph);
  87. std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  88. auto valid_inputs = graph->valid_inputs();
  89. size_t real_args_size = 0;
  90. std::vector<BaseRef> real_args = {};
  91. for (size_t i = 0; i < args.size(); i++) {
  92. if (utils::isa<AnfNodePtr>(args[i])) {
  93. auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem});
  94. for (auto &real_arg : tmp_args) {
  95. auto anf_node = utils::cast<AnfNodePtr>(real_arg);
  96. MS_EXCEPTION_IF_NULL(anf_node);
  97. auto abstract = anf_node->abstract();
  98. MS_EXCEPTION_IF_NULL(abstract);
  99. // create multiple parameters if is a tuple output real kernel
  100. if (abstract->isa<abstract::AbstractTuple>() &&
  101. !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
  102. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  103. real_args_size += tuple_abstract->size();
  104. continue;
  105. }
  106. real_args_size += 1;
  107. real_args.push_back(real_arg);
  108. }
  109. } else {
  110. real_args_size += 1;
  111. real_args.push_back(args[i]);
  112. }
  113. }
  114. if (graph_inputs.size() != valid_inputs.size()) {
  115. MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
  116. << ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
  117. }
  118. if (real_args_size != graph_inputs.size()) {
  119. for (size_t j = 0; j < valid_inputs.size(); j++) {
  120. if (valid_inputs[j]) {
  121. MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
  122. }
  123. }
  124. MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
  125. << " not equal";
  126. }
  127. return real_args;
  128. }
  129. void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
  130. MS_EXCEPTION_IF_NULL(kernel_graph);
  131. // clear input parameter memory resource
  132. for (const auto &input_node : kernel_graph->inputs()) {
  133. MS_EXCEPTION_IF_NULL(input_node);
  134. AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  135. }
  136. // clear input value node memory resource
  137. for (const auto &value_node : kernel_graph->graph_value_nodes()) {
  138. MS_EXCEPTION_IF_NULL(value_node);
  139. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  140. }
  141. for (const auto &cnode : kernel_graph->execution_order()) {
  142. MS_EXCEPTION_IF_NULL(cnode);
  143. // clear output memory resource
  144. for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
  145. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  146. }
  147. // clear workspace memory resource
  148. auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  149. MS_EXCEPTION_IF_NULL(kernel_mod);
  150. auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
  151. for (size_t index = 0; index < workspace_lists.size(); ++index) {
  152. AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
  153. }
  154. }
  155. }
  156. std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
  157. std::vector<CNodePtr> cnodes = {};
  158. size_t i = 0;
  159. for (const auto &anf : anf_nodes) {
  160. MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
  161. MS_EXCEPTION_IF_NULL(anf);
  162. if (anf->isa<CNode>()) {
  163. cnodes.push_back(anf->cast<CNodePtr>());
  164. }
  165. }
  166. return cnodes;
  167. }
  168. static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes,
  169. const std::set<PrimitivePtr> &cut_prims) {
  170. size_t after_cut_index = 0;
  171. std::vector<std::vector<CNodePtr>> ret;
  172. for (size_t i = 0; i < cnodes.size(); ++i) {
  173. bool is_cut_node = false;
  174. for (auto &prim : cut_prims) {
  175. if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) {
  176. is_cut_node = true;
  177. break;
  178. }
  179. }
  180. if (is_cut_node) {
  181. // is call and not switch call,cut to 3 lists
  182. if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) {
  183. // if is not a call,cut to 2 lists
  184. ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
  185. after_cut_index = i;
  186. } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) {
  187. ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i);
  188. ret.emplace_back(1, cnodes[i]);
  189. after_cut_index = i + 1;
  190. continue;
  191. }
  192. }
  193. // get last child graph list
  194. if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
  195. ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end());
  196. continue;
  197. }
  198. }
  199. return ret;
  200. }
  201. static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
  202. KernelGraph *child_graph) {
  203. MS_EXCEPTION_IF_NULL(child_graph);
  204. MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
  205. if (args.empty()) {
  206. return;
  207. }
  208. if (parameters.size() != args.size()) {
  209. MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
  210. << " and args size:" << args.size() << " not equal!";
  211. }
  212. child_graph->SetExecOrderByDefault();
  213. for (size_t i = 0; i < parameters.size(); i++) {
  214. if (args[i] == parameters[i]) {
  215. child_graph->SetRealInput(parameters[i], args[i]);
  216. MS_LOG(INFO) << "Parameter and arg are same";
  217. continue;
  218. }
  219. // if arg is a parameter ,then reuse this parameter
  220. if (args[i]->isa<Parameter>()) {
  221. MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
  222. << " reuse parameter:" << args[i]->DebugString()
  223. << " of graph:" << AnfAlgo::GetGraphId(args[i].get());
  224. child_graph->ReplaceNode(parameters[i], args[i]);
  225. continue;
  226. }
  227. child_graph->SetRealInput(parameters[i], args[i]);
  228. }
  229. }
  230. // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
  231. // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
  232. static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
  233. auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
  234. for (auto &call_node : call_nodes) {
  235. MS_EXCEPTION_IF_NULL(call_node);
  236. auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
  237. if (child_graphs.size() == 1) {
  238. MS_EXCEPTION_IF_NULL(child_graphs[0]);
  239. std::vector<AnfNodePtr> real_args =
  240. std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
  241. std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
  242. BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
  243. call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
  244. } else if (child_graphs.size() == 2) {
  245. auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
  246. auto switch_node = call_node->input(1);
  247. MS_EXCEPTION_IF_NULL(switch_node);
  248. auto switch_cnode = switch_node->cast<CNodePtr>();
  249. MS_EXCEPTION_IF_NULL(switch_cnode);
  250. auto partial = switch_cnode->input(input_index);
  251. MS_EXCEPTION_IF_NULL(partial);
  252. auto partial_cnode = partial->cast<CNodePtr>();
  253. MS_EXCEPTION_IF_NULL(partial_cnode);
  254. auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
  255. partial_cnode->set_inputs(
  256. std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
  257. return ret;
  258. };
  259. BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
  260. BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
  261. }
  262. }
  263. }
  264. static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
  265. const NotNull<std::set<KernelGraphPtr> *> memo) {
  266. memo->insert(graph.get());
  267. MS_LOG(INFO) << "start graph id:" << graph->graph_id();
  268. for (auto &child_graph : graph->child_graph_order()) {
  269. if (memo->find(child_graph) != memo->end()) {
  270. MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
  271. << ",parent graph:" << graph->parent_graph()->graph_id();
  272. continue;
  273. }
  274. RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo);
  275. }
  276. // this action should from bottom to top
  277. graph->UpdateCallRealInput();
  278. }
  279. } // namespace
  280. GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  281. MS_LOG(INFO) << "start";
  282. // construct graph, if successfully, graph_sum_ + 1
  283. auto graph = ConstructKernelGraph(lst, outputs);
  284. auto graph_id = graph->graph_id();
  285. MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  286. return graph_id;
  287. }
  288. GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  289. MS_LOG(INFO) << "start";
  290. auto graph = ConstructKernelGraph(func_graph);
  291. // split switch
  292. SplitGraphs(NOT_NULL(graph));
  293. // insert goto labels and label_sets
  294. LinkChildGraphs(NOT_NULL(graph));
  295. // resource initialize
  296. InitRuntimeResource();
  297. // assign label
  298. AssignLabel(NOT_NULL(graph));
  299. // recurse compile child graph
  300. std::set<KernelGraphPtr> memo;
  301. RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
  302. // root graph valiate,include genearte execute order and so on
  303. RootGraphExecutorValidate(NOT_NULL(graph));
  304. // adjust kernel
  305. AdjustKernel(graph);
  306. // assign stream
  307. AssignStream(graph);
  308. // build kernel
  309. BuildKernel(graph);
  310. // alloc mem
  311. MemoryAlloc(graph.get());
  312. // task generate
  313. GenerateTaskInfo(graph);
  314. // load task into device
  315. LoadTask(graph);
  316. // return the graph id to backend
  317. auto graph_id = graph->graph_id();
  318. return graph_id;
  319. }
  320. void AscendSession::BuildGraph(GraphId graph_id) {
  321. MS_LOG(INFO) << "start";
  322. auto graph = GetGraph(graph_id);
  323. MS_EXCEPTION_IF_NULL(graph);
  324. // resource initialize
  325. InitRuntimeResource();
  326. // multiple graph handle
  327. if (graph_id == final_graph_id_) {
  328. if (!graph->executable()) {
  329. return;
  330. }
  331. // insert assigns to child graph
  332. InsertAllAssigns();
  333. // insert switch and active to child graph
  334. MergeSwitchCompile();
  335. // OptChildGraphs
  336. auto graph_order = GetGraphOrder(final_graph_id_);
  337. auto &graph_type = GetGraphOrderType(final_graph_id_);
  338. for (size_t i = 0; i < graph_order.size(); i++) {
  339. if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
  340. continue;
  341. }
  342. MS_LOG(INFO) << "Start build child graph " << graph_order[i];
  343. auto child_graph = GetGraph(graph_order[i]);
  344. CompileChildGraph(child_graph);
  345. }
  346. // merge child graph
  347. MergeGraphExecOrder();
  348. } else {
  349. auto single_graph = GetGraph(graph_id);
  350. CompileChildGraph(single_graph);
  351. // set the distinction label of single graph
  352. single_graph->set_stream_distinction_label(graph_id);
  353. single_graph->UpdateExecuteKernelStreamLabel();
  354. }
  355. // adjust execution order because merge child graph and other special operations
  356. AdjustKernel(graph);
  357. // Assign streams for control sink and hccl and so on
  358. AssignStream(graph);
  359. device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
  360. // build kernel if node is cnode
  361. BuildKernel(graph);
  362. auto ms_context = MsContext::GetInstance();
  363. MS_EXCEPTION_IF_NULL(ms_context);
  364. if (ms_context->precompile_only()) {
  365. MS_LOG(INFO) << "Precompile only, stop in build kernel step";
  366. } else {
  367. // alloc memory, including static memory and dynamic memory
  368. MemoryAlloc(graph.get());
  369. // generate task info for task sink mode
  370. GenerateTaskInfo(graph);
  371. // load task info to device if it is sink mode
  372. LoadTask(graph);
  373. }
  374. // sync the inital const tensor to device
  375. SyncInitialTenosrToDevice();
  376. ExportChildGraphs(graph_id);
  377. MS_LOG(INFO) << "end";
  378. }
  379. void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
  380. MS_EXCEPTION_IF_NULL(child_graph);
  381. MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
  382. opt::AscendBackendIRFusionOptimization(child_graph);
  383. // select kernel build info
  384. SelectKernel(*child_graph);
  385. // convert kernel Graph to model
  386. predictmodel::StepConvertGraph(child_graph);
  387. // optimize graph
  388. HardwareOptimize(child_graph);
  389. // assign static memory of parameters
  390. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  391. MS_EXCEPTION_IF_NULL(runtime_instance);
  392. runtime_instance->AssignStaticMemoryInput(child_graph.get());
  393. runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
  394. }
  395. void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  396. VectorRef *const outputs) {
  397. MS_LOG(INFO) << "start";
  398. auto kernel_graph = GetGraph(graph_id);
  399. DumpIR("./run_graph.ir", kernel_graph);
  400. MS_EXCEPTION_IF_NULL(kernel_graph);
  401. // if none of child graph and no anf output exists
  402. if (!kernel_graph->executable()) {
  403. MS_LOG(INFO) << "No child graph has anf output";
  404. UpdateOutputs(kernel_graph, outputs, inputs);
  405. return;
  406. }
  407. // load input data from user input
  408. LoadInputData(kernel_graph, inputs);
  409. // convert inputs to model
  410. predictmodel::StepConvertWeight(inputs);
  411. {
  412. py::gil_scoped_release release;
  413. // run task on device
  414. ExecTask(kernel_graph);
  415. }
  416. // get result from device
  417. UpdateOutputs(kernel_graph, outputs, inputs);
  418. // summary
  419. Summary(kernel_graph.get());
  420. // dump used for debug
  421. Dump(kernel_graph);
  422. MS_LOG(INFO) << "Finish!";
  423. }
  424. void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const {
  425. MS_LOG(INFO) << "Start";
  426. // data layout optimization
  427. opt::RunOpAscendDataLayout(kernel_graph);
  428. // mixed precision optimization
  429. opt::AscendMixPrecision(kernel_graph);
  430. MS_LOG(INFO) << "Finish";
  431. }
  432. void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  433. MS_LOG(INFO) << "Start!";
  434. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  435. MS_EXCEPTION_IF_NULL(runtime_instance);
  436. bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
  437. if (!ret_ok) {
  438. MS_LOG(EXCEPTION) << "run task error!";
  439. }
  440. MS_LOG(INFO) << "Finish!";
  441. }
  442. bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
  443. if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
  444. return true;
  445. }
  446. return false;
  447. }
  448. void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  449. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
  450. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
  451. if (GraphCacheExist(graph_info)) {
  452. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !";
  453. return;
  454. }
  455. // construct graph include one op
  456. auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
  457. MS_EXCEPTION_IF_NULL(graph);
  458. opt::RunOpAscendBackendIRFusionOptimization(graph);
  459. // kernel select
  460. SelectKernel(*graph);
  461. // optimize
  462. RunOpHardwareOptimize(graph);
  463. // init runtime resource
  464. InitRuntimeResource();
  465. // build kernel
  466. RunOpAdjustKernel(graph);
  467. BuildKernel(graph);
  468. run_op_graphs_[graph_info] = graph;
  469. MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
  470. }
  471. py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  472. const std::vector<tensor::TensorPtr> &input_tensors) {
  473. auto graph = run_op_graphs_[graph_info];
  474. MS_EXCEPTION_IF_NULL(graph);
  475. MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
  476. // malloc mem
  477. RunOpMemoryAlloc(input_tensors, graph.get());
  478. // load input data to device
  479. LoadInputData(graph, input_tensors);
  480. // run op
  481. RunOpExecTask(graph);
  482. // get output
  483. VectorRef outputs;
  484. UpdateOutputs(graph, &outputs, input_tensors);
  485. // trans output to tuple
  486. auto output_tensors = TransformBaseRefListToTuple(outputs);
  487. if (!utils::isa<PyObjectRef>(output_tensors) ||
  488. !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
  489. MS_LOG(EXCEPTION) << "The output tensors should be a tuple !";
  490. }
  491. py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
  492. py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
  493. ClearRunOpMemoryResource(graph);
  494. MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
  495. return tuple_tensors;
  496. }
  497. // compile graph steps
  498. void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
  499. MS_LOG(INFO) << "Start!";
  500. size_t raise_precision_count = 0;
  501. size_t reduce_precision_count = 0;
  502. for (const auto &cnode : kernel_graph.execution_order()) {
  503. auto status = device::ascend::SelectKernelInfo(cnode);
  504. if (status == device::ascend::kStatusRaisePrecision) {
  505. raise_precision_count++;
  506. } else if (status == device::ascend::kStatusReducePrecision) {
  507. reduce_precision_count++;
  508. }
  509. MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
  510. }
  511. if (raise_precision_count > 0) {
  512. MS_LOG(WARNING) << "There has " << raise_precision_count
  513. << " node/nodes used raise precision to selected the kernel!";
  514. }
  515. if (reduce_precision_count > 0) {
  516. MS_LOG(WARNING) << "There has " << reduce_precision_count
  517. << " node/nodes used reduce precision to selected the kernel!";
  518. }
  519. MS_LOG(INFO) << "Finish!";
  520. }
  521. void AscendSession::InitRuntimeResource() {
  522. MS_LOG(INFO) << "Start!";
  523. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  524. MS_EXCEPTION_IF_NULL(runtime_instance);
  525. if (!runtime_instance->Init()) {
  526. MS_LOG(EXCEPTION) << "Kernel runtime init error.";
  527. }
  528. MS_LOG(INFO) << "Finish!";
  529. }
  530. void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  531. MS_LOG(INFO) << "HardwareOptimize start!";
  532. opt::AscendBackendOptimization(kernel_graph);
  533. MS_EXCEPTION_IF_NULL(kernel_graph);
  534. kernel_graph->SetExecOrderByDefault();
  535. MS_LOG(INFO) << "HardwareOptimize Finish!";
  536. }
  537. void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  538. MS_LOG(INFO) << "Start!";
  539. device::KernelAdjust::GetInstance().Reorder(kernel_graph);
  540. opt::HideNopNode(kernel_graph.get());
  541. // Insert CLearZero op
  542. // prepare for next step from json get atomic info
  543. BuildKernel(kernel_graph);
  544. device::ascend::KernelBuildPreprocess(kernel_graph.get());
  545. device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
  546. auto context_ptr = MsContext::GetInstance();
  547. MS_EXCEPTION_IF_NULL(context_ptr);
  548. bool save_graphs = context_ptr->save_graphs_flag();
  549. auto save_graphs_path = context_ptr->save_graphs_path();
  550. if (save_graphs_path.empty()) {
  551. save_graphs_path = ".";
  552. }
  553. if (save_graphs) {
  554. std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir";
  555. DumpIR(file_path, kernel_graph);
  556. }
  557. MS_LOG(INFO) << "Finish!";
  558. }
  559. void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  560. MS_LOG(INFO) << "Start!";
  561. opt::HideNopNode(kernel_graph.get());
  562. // Insert CLearZero op
  563. // prepare for next step from json get atomic info
  564. BuildKernel(kernel_graph);
  565. device::ascend::KernelBuildPreprocess(kernel_graph.get());
  566. MS_LOG(INFO) << "Finish!";
  567. }
  568. void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  569. MS_LOG(INFO) << "Start!";
  570. device::ascend::AscendStreamAssign::GetInstance().AssignStreamNew(kernel_graph);
  571. MS_LOG(INFO) << "Finish!";
  572. }
  573. void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
  574. MS_LOG(INFO) << "Start!";
  575. device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
  576. MS_LOG(INFO) << "Finish!";
  577. }
  578. void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  579. MS_LOG(INFO) << "Start!";
  580. struct timeval start_time, end_time;
  581. (void)gettimeofday(&start_time, nullptr);
  582. auto ret = device::ascend::KernelBuild(kernel_graph.get());
  583. if (!ret) {
  584. MS_LOG(EXCEPTION) << "Kernel build error.";
  585. }
  586. (void)gettimeofday(&end_time, nullptr);
  587. const uint64_t kUSecondInSecond = 1000000;
  588. uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  589. cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  590. MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost;
  591. MS_LOG(INFO) << "Finish!";
  592. }
  593. void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  594. MS_LOG(INFO) << "Start!";
  595. MS_EXCEPTION_IF_NULL(kernel_graph);
  596. opt::RemoveNopNode(kernel_graph);
  597. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  598. MS_EXCEPTION_IF_NULL(runtime_instance);
  599. runtime_instance->AssignMemory(kernel_graph);
  600. MS_LOG(INFO) << "Finish!";
  601. }
  602. void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
  603. KernelGraph *kernel_graph) const {
  604. MS_LOG(INFO) << "Start memory alloc!";
  605. MS_EXCEPTION_IF_NULL(kernel_graph);
  606. opt::RemoveNopNode(kernel_graph);
  607. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  608. MS_EXCEPTION_IF_NULL(runtime_instance);
  609. runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
  610. MS_LOG(INFO) << "Finish!";
  611. }
  612. void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  613. MS_LOG(INFO) << "Start!";
  614. (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
  615. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  616. MS_EXCEPTION_IF_NULL(runtime_instance);
  617. bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
  618. if (!ret_ok) {
  619. MS_LOG(EXCEPTION) << "Generate task error!";
  620. }
  621. MS_LOG(INFO) << "Finish!";
  622. }
  623. void AscendSession::LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  624. MS_LOG(INFO) << "Start!";
  625. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  626. MS_EXCEPTION_IF_NULL(runtime_instance);
  627. bool ret_ok = runtime_instance->LoadTask(kernel_graph.get());
  628. if (!ret_ok) {
  629. MS_LOG(EXCEPTION) << "Load task error!";
  630. }
  631. MS_LOG(INFO) << "Finish!";
  632. }
  633. void AscendSession::ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  634. MS_LOG(INFO) << "Start!";
  635. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  636. MS_EXCEPTION_IF_NULL(runtime_instance);
  637. bool ret_ok = runtime_instance->Run(kernel_graph.get());
  638. if (!ret_ok) {
  639. MS_LOG(EXCEPTION) << "run task error!";
  640. }
  641. MS_LOG(INFO) << "Finish!";
  642. }
  643. void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  644. MS_LOG(INFO) << "Start!";
  645. MS_EXCEPTION_IF_NULL(kernel_graph);
  646. auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  647. MS_EXCEPTION_IF_NULL(runtime_instance);
  648. (void)runtime_instance->DumpData(kernel_graph.get());
  649. MS_LOG(INFO) << "Finish!";
  650. }
  651. void AscendSession::ExportChildGraphs(const GraphId graph_id) {
  652. #ifdef ENABLE_DUMP_IR
  653. auto context_ptr = MsContext::GetInstance();
  654. MS_EXCEPTION_IF_NULL(context_ptr);
  655. bool save_graphs = context_ptr->save_graphs_flag();
  656. if (!save_graphs) {
  657. return;
  658. }
  659. auto save_graphs_path = context_ptr->save_graphs_path();
  660. if (save_graphs_path.empty()) {
  661. save_graphs_path = ".";
  662. }
  663. if (graph_id == final_graph_id_) {
  664. auto &graph_order = GetGraphOrder(final_graph_id_);
  665. auto &graph_type = GetGraphOrderType(final_graph_id_);
  666. for (size_t i = 0; i < graph_order.size(); i++) {
  667. if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
  668. continue;
  669. }
  670. auto child_graph = GetGraph(graph_order[i]);
  671. MS_LOG(DEBUG) << "Start export child graph " << graph_order[i];
  672. std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir";
  673. DumpIR(file_path, child_graph, true);
  674. DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id()));
  675. MS_LOG(DEBUG) << "End export child graph " << graph_order[i];
  676. }
  677. }
  678. #endif
  679. }
  680. GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
  681. MS_LOG(INFO) << "Start! Args size " << args.size();
  682. auto final_graph = NewKernelGraph();
  683. final_graph_id_ = final_graph->graph_id();
  684. MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
  685. // init private variables and bind them with final_graph_id
  686. graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
  687. graph_order_types_[final_graph_id_] = std::vector<GraphType>();
  688. for (const auto &parameter : args) {
  689. MS_EXCEPTION_IF_NULL(parameter);
  690. if (!parameter->isa<Parameter>()) {
  691. MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!";
  692. }
  693. AnfNodePtr parameter_backend = nullptr;
  694. // if function return UINT_MAX,the parameter is not exist in child graph
  695. auto parameter_belong_graph_id = GetGraphIdByNode(parameter);
  696. if (parameter_belong_graph_id == kInvalidGraphId) {
  697. parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get());
  698. final_graph->FrontBackendlMapAdd(parameter, parameter_backend);
  699. MS_LOG(INFO) << "New parameter" << parameter->DebugString() << "in final_graph";
  700. } else {
  701. // parametr is a parameter of child graph
  702. auto graph = GetGraph(parameter_belong_graph_id);
  703. MS_EXCEPTION_IF_NULL(graph);
  704. MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph ["
  705. << parameter_belong_graph_id << "]";
  706. parameter_backend = graph->GetBackendAnfByFrontAnf(parameter);
  707. // add parameter in backend to final graph inputs
  708. auto final_graph_inputs = final_graph->MutableInputs();
  709. MS_EXCEPTION_IF_NULL(final_graph_inputs);
  710. final_graph_inputs->push_back(parameter_backend);
  711. }
  712. MS_EXCEPTION_IF_NULL(parameter_backend);
  713. MS_LOG(INFO) << "parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
  714. << AnfAlgo::GetGraphId(parameter_backend.get());
  715. }
  716. MS_LOG(INFO) << "End final_graph_id " << final_graph_id_;
  717. return final_graph_id_;
  718. }
  719. void AscendSession::GetSummaryNodes(const KernelGraph *graph,
  720. std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  721. MS_LOG(DEBUG) << "Update summary Start";
  722. MS_EXCEPTION_IF_NULL(graph);
  723. MS_EXCEPTION_IF_NULL(summary);
  724. summary->clear();
  725. // if final graph have no child graph
  726. auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
  727. if (graph_order_iter == graph_execute_orders_.end()) {
  728. SessionBasic::GetSummaryNodes(graph, summary);
  729. return;
  730. }
  731. // for every child graph, find summary nodes
  732. auto graph_order = GetGraphOrder(graph->graph_id());
  733. for (size_t i = 0; i < graph_order.size(); i++) {
  734. auto child_graph = GetGraph(graph_order[i]);
  735. SessionBasic::GetSummaryNodes(child_graph.get(), summary);
  736. }
  737. MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
  738. }
  739. AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
  740. auto fake_graph = GetGraph(fake_graph_id);
  741. auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
  742. auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
  743. auto parameter = fake_graph->NewParameter();
  744. MS_EXCEPTION_IF_NULL(parameter);
  745. parameter->set_abstract(abstract);
  746. auto new_parameter = fake_graph->NewParameter(parameter);
  747. // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory.
  748. auto graph_inputs = fake_graph->MutableInputs();
  749. MS_EXCEPTION_IF_NULL(graph_inputs);
  750. graph_inputs->push_back(new_parameter);
  751. return new_parameter;
  752. };
  753. auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr {
  754. MS_EXCEPTION_IF_NULL(cnode);
  755. auto abstract = cnode->abstract();
  756. MS_EXCEPTION_IF_NULL(abstract);
  757. // create multiple parameters if is a tuple output real kernel
  758. if (abstract->isa<abstract::AbstractTuple>()) {
  759. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  760. MS_EXCEPTION_IF_NULL(tuple_abstract);
  761. MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
  762. return create_parameter((*tuple_abstract)[output_idx]);
  763. }
  764. return create_parameter(cnode->abstract());
  765. };
  766. if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) {
  767. std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  768. auto make_tuple = output_item_with_index.first->cast<CNodePtr>();
  769. MS_EXCEPTION_IF_NULL(make_tuple);
  770. for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
  771. auto input = make_tuple->inputs()[i];
  772. make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input));
  773. }
  774. return fake_graph->NewCNode(make_tuple_inputs);
  775. }
  776. return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
  777. }
  778. void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
  779. // get the backend anf node related to the output node of front
  780. auto output_from_graph_id = GetGraphIdByNode(node);
  781. auto output_from_graph = GetGraph(output_from_graph_id);
  782. MS_EXCEPTION_IF_NULL(node);
  783. MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
  784. << "] to final graph";
  785. MS_EXCEPTION_IF_NULL(output_from_graph);
  786. auto final_graph = GetGraph(final_graph_id_);
  787. MS_EXCEPTION_IF_NULL(final_graph);
  788. // if output is from final graph,it remarks no child graph exist
  789. if (final_graph_id_ == output_from_graph_id) {
  790. MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
  791. final_graph->set_output(ConstructOutput({node}, final_graph));
  792. final_graph->set_executable(false);
  793. return;
  794. }
  795. final_graph->set_output(output_from_graph->output());
  796. }
  797. void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
  798. auto value_node = NewValueNode(value);
  799. auto kernel_info = std::make_shared<device::KernelInfo>();
  800. value_node->set_kernel_info(kernel_info);
  801. value_node->set_abstract(abstract::FromValue(value));
  802. auto final_graph = GetGraph(final_graph_id_);
  803. MS_EXCEPTION_IF_NULL(final_graph);
  804. final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
  805. final_graph->set_executable(false);
  806. MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
  807. }
  808. void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
  809. for (auto &output : vec_output) {
  810. if (utils::isa<AnfNodePtr>(output)) {
  811. auto output_anf_node = utils::cast<AnfNodePtr>(output);
  812. SetFinalGraphOutput(output_anf_node);
  813. } else if (utils::isa<ValuePtr>(output)) {
  814. auto value = utils::cast<ValuePtr>(output);
  815. SetFinalGraphOutput(value);
  816. } else {
  817. MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
  818. }
  819. }
  820. }
  821. void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
  822. if (utils::isa<AnfNodePtr>(output)) {
  823. auto output_anf_node = utils::cast<AnfNodePtr>(output);
  824. SetFinalGraphOutput(output_anf_node);
  825. } else if (utils::isa<ValuePtr>(output)) {
  826. auto value = utils::cast<ValuePtr>(output);
  827. SetFinalGraphOutput(value);
  828. } else if (utils::isa<VectorRef>(output)) {
  829. auto vec_output = utils::cast<VectorRef>(output);
  830. SetFinalGraphOutput(vec_output);
  831. } else {
  832. MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
  833. }
  834. }
  835. KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
  836. auto it = graphs_.find(graph_id);
  837. if (it == graphs_.end()) {
  838. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  839. return nullptr;
  840. }
  841. return it->second;
  842. }
  843. void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) {
  844. MS_LOG(INFO) << "Start!";
  845. MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]";
  846. auto condition_graph = GetGraph(condition_graph_id);
  847. MS_EXCEPTION_IF_NULL(condition_graph);
  848. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{1});
  849. int32_t *val = nullptr;
  850. val = static_cast<int32_t *>(tensor->data_c(true));
  851. MS_EXCEPTION_IF_NULL(val);
  852. *val = 0;
  853. auto value_node = std::make_shared<ValueNode>(tensor);
  854. value_node->set_abstract(abstract::FromValue(tensor, false));
  855. auto counter_const = condition_graph->NewValueNode(value_node);
  856. condition_graph->AddValueNodeToGraph(counter_const);
  857. // create a new switch op
  858. auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
  859. auto cond_output_it = condition_output_.find(condition_graph_id);
  860. if (cond_output_it == condition_output_.end()) {
  861. MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
  862. }
  863. auto cond_output_kernel =
  864. AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first;
  865. MS_EXCEPTION_IF_NULL(cond_output_kernel);
  866. std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
  867. CNodePtr switch_node = condition_graph->NewCNode(inputs);
  868. MS_EXCEPTION_IF_NULL(switch_node);
  869. switch_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  870. AnfAlgo::SetGraphId(condition_graph_id, switch_node.get());
  871. // set attr: cond_ RT_GREATER
  872. AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue<int>(static_cast<int>(RT_GREATER)), switch_node);
  873. // set attr:data_type
  874. AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue<int>(static_cast<int>(RT_SWITCH_INT64)), switch_node);
  875. // set attr:true branch graph id ,which is same to stream distinction label
  876. AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(true_graph_id), switch_node);
  877. // append switch at the end of condition graph
  878. auto return_node = condition_graph->get_return();
  879. MS_EXCEPTION_IF_NULL(return_node);
  880. InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node);
  881. MS_LOG(INFO) << "Finish!";
  882. }
  883. void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
  884. auto &graph_execute_order = GetGraphOrder(final_graph_id_);
  885. auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  886. auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id);
  887. if (false_index == kInvalidIndex || false_index == 0) {
  888. return;
  889. }
  890. for (int i = SizeToInt(false_index) - 1; i >= 0; i--) {
  891. size_t graph_index = IntToSize(i);
  892. if (graph_index >= graph_execute_order.size()) {
  893. MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]";
  894. }
  895. if (graph_order_type[graph_index] == COMMON_GRAPH) {
  896. auto true_last_id = graph_execute_order[graph_index];
  897. MS_LOG(INFO) << "The last graph of if true branch is " << true_last_id;
  898. auto true_last = GetGraph(true_last_id);
  899. auto final_graph = GetGraph(final_graph_id_);
  900. MS_EXCEPTION_IF_NULL(final_graph);
  901. auto false_last = GetGraph(false_graph_id);
  902. MS_EXCEPTION_IF_NULL(true_last);
  903. MS_EXCEPTION_IF_NULL(false_last);
  904. MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id;
  905. // create fake output
  906. auto fake_output_graph = NewKernelGraph();
  907. graph_execute_order.push_back(fake_output_graph->graph_id());
  908. graph_order_type.push_back(COMMON_GRAPH);
  909. fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
  910. final_graph->set_output(fake_output_graph->output());
  911. InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output());
  912. InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output());
  913. // insert stream active for loop sink
  914. auto context_ptr = MsContext::GetInstance();
  915. MS_EXCEPTION_IF_NULL(context_ptr);
  916. if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
  917. ConfigManager::GetInstance().iter_num() > 1) {
  918. // insert active in true graph, another active will be inserted in kernel adjust
  919. InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel);
  920. }
  921. break;
  922. }
  923. }
  924. }
  925. void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id,
  926. const AnfNodePtr &output) {
  927. if (switches_.find(cond_graph_id) != switches_.end()) {
  928. MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before ";
  929. return;
  930. }
  931. switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id);
  932. condition_output_[cond_graph_id] = output;
  933. MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
  934. // set the type of condition graph
  935. auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
  936. auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  937. if (cond_graph_index >= graph_order_type.size()) {
  938. MS_LOG(EXCEPTION) << "cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size();
  939. }
  940. graph_order_type[cond_graph_index] = CONDITION_GRAPH;
  941. // update distinction label of false graph,update before merge to sure the distinction
  942. if (false_graph_id != kInvalidGraphId) {
  943. // false graph and condition in graph same stream
  944. auto condition_graph = GetGraph(cond_graph_id);
  945. SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
  946. // if false graph is a condition graph and has been switch compiled before,it's false should be updated again
  947. auto cond_it = switches_.find(false_graph_id);
  948. while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) {
  949. cond_graph_id = cond_it->first;
  950. false_graph_id = cond_it->second.second;
  951. condition_graph = GetGraph(cond_graph_id);
  952. SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
  953. cond_it = switches_.find(false_graph_id);
  954. }
  955. }
  956. } // namespace session
  957. void AscendSession::MergeSwitchCompile() {
  958. auto graph_execute_order = GetGraphOrder(final_graph_id_);
  959. auto &graph_order_type = GetGraphOrderType(final_graph_id_);
  960. for (auto switch_compile : switches_) {
  961. auto cond_graph_id = switch_compile.first;
  962. auto true_graph_id = switch_compile.second.first;
  963. auto false_graph_id = switch_compile.second.second;
  964. MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
  965. auto condition_graph = GetGraph(cond_graph_id);
  966. auto final_graph = GetGraph(final_graph_id_);
  967. MS_EXCEPTION_IF_NULL(condition_graph);
  968. MS_EXCEPTION_IF_NULL(final_graph);
  969. // insert switch to condition graph
  970. InsertSwitchToGraph(cond_graph_id, true_graph_id);
  971. auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
  972. auto prev_graph_id = kInvalidGraphId;
  973. // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph
  974. if (cond_graph_index == 0 && !final_graph->execution_order().empty()) {
  975. prev_graph_id = final_graph_id_;
  976. // set the distinction label of final graph
  977. SetStreamDistinctionLabel(final_graph, final_graph_id_, true);
  978. // if condition graph is not the first graph
  979. } else if ((cond_graph_index - 1 < graph_execute_order.size()) &&
  980. (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) {
  981. prev_graph_id = graph_execute_order[cond_graph_index - 1];
  982. }
  983. // insert stream active to common graph
  984. if (prev_graph_id != kInvalidGraphId) {
  985. InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label());
  986. }
  987. // if this is a 'if' condition
  988. auto it = while_condition_graphs_.find(cond_graph_id);
  989. if (it == while_condition_graphs_.end()) {
  990. CopyOutputOfIf(false_graph_id);
  991. } else {
  992. // if it is a while,insert a stream active to true graph
  993. GraphId from_graph = it->second;
  994. InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label());
  995. }
  996. }
  997. MS_LOG(INFO) << "Finish!";
  998. }
  999. void AscendSession::InsertAllAssigns() {
  1000. std::vector<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
  1001. for (auto assign : assigns_) {
  1002. auto front_anf = std::get<0>(assign);
  1003. auto to_graph_id = std::get<1>(assign);
  1004. auto input_idx = std::get<2>(assign);
  1005. auto to_graph = GetGraph(to_graph_id);
  1006. MS_EXCEPTION_IF_NULL(to_graph);
  1007. std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  1008. if (input_idx >= graph_inputs.size()) {
  1009. MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
  1010. }
  1011. auto backend_parameter = graph_inputs[input_idx];
  1012. assigns.emplace_back(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
  1013. }
  1014. // erase the repeat assign
  1015. std::set<std::pair<AnfNodePtr, AnfNodePtr>> inserted_nodes;
  1016. for (auto &assign : assigns) {
  1017. auto front_anf = assign.first;
  1018. auto backend_parameter = assign.second;
  1019. auto from_graph_id = GetGraphIdByNode(front_anf);
  1020. auto from_graph = GetGraph(from_graph_id);
  1021. MS_EXCEPTION_IF_NULL(from_graph);
  1022. auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
  1023. if (inserted_nodes.find(assign) == inserted_nodes.end()) {
  1024. InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
  1025. (void)inserted_nodes.insert(assign);
  1026. }
  1027. }
  1028. }
  1029. // insert active to graph
  1030. void AscendSession::SetActive(GraphId from, GraphId to) {
  1031. if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
  1032. MS_LOG(WARNING) << " to " << to << " has been exits in map,from " << from << ",exist from "
  1033. << while_condition_graphs_[to];
  1034. return;
  1035. }
  1036. MS_LOG(INFO) << "From " << from << " to " << to;
  1037. auto &graph_order = GetGraphOrder(final_graph_id_);
  1038. auto &graph_type = GetGraphOrderType(final_graph_id_);
  1039. std::vector<GraphId> graph_order_new;
  1040. std::vector<GraphType> graph_type_new;
  1041. for (size_t i = 0; i < graph_order.size(); i++) {
  1042. auto graph_id = graph_order[i];
  1043. graph_order_new.push_back(graph_id);
  1044. graph_type_new.push_back(graph_type[i]);
  1045. if (from == graph_id) {
  1046. graph_order_new.push_back(kInvalidGraphId);
  1047. graph_type_new.push_back(BRANCH_END);
  1048. }
  1049. }
  1050. graph_order = graph_order_new;
  1051. graph_type = graph_type_new;
  1052. // set the graph type of condition graph
  1053. graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH;
  1054. // record the condition graph into while condition set
  1055. while_condition_graphs_[to] = from;
  1056. }
  1057. void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) {
  1058. MS_LOG(INFO) << "Start!";
  1059. MS_EXCEPTION_IF_NULL(front_anf);
  1060. auto from_graph_id = GetGraphIdByNode(front_anf);
  1061. auto from_graph = GetGraph(from_graph_id);
  1062. MS_EXCEPTION_IF_NULL(from_graph);
  1063. auto to_graph = GetGraph(to_graph_id);
  1064. MS_EXCEPTION_IF_NULL(to_graph);
  1065. std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  1066. if (input_idx >= graph_inputs.size()) {
  1067. MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
  1068. }
  1069. auto backend_parameter = graph_inputs[input_idx];
  1070. MS_EXCEPTION_IF_NULL(backend_parameter);
  1071. auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
  1072. MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
  1073. << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
  1074. << "]";
  1075. // a node should not assign to itself
  1076. if (backend_arg.get() == backend_parameter.get()) {
  1077. return;
  1078. }
  1079. // if arg is the the parameter of child graph,it is parameter of final graph too
  1080. if (front_anf->isa<Parameter>()) {
  1081. MS_EXCEPTION_IF_NULL(backend_arg);
  1082. MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
  1083. << "] will be replaced.";
  1084. to_graph->ReplaceNode(backend_parameter, backend_arg);
  1085. return;
  1086. }
  1087. MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node"
  1088. << backend_parameter->DebugString() << "of graph " << to_graph_id;
  1089. assigns_.emplace_back(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx));
  1090. }
  1091. void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id,
  1092. size_t input_idx) {
  1093. MS_LOG(INFO) << "Start!";
  1094. std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx);
  1095. initial_tenosrs_[graph_input_pair] = front_tensor;
  1096. MS_LOG(INFO) << "Finish!";
  1097. }
  1098. void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
  1099. MS_LOG(INFO) << "to_graph_id " << to_graph_id;
  1100. auto &graph_order = GetGraphOrder(final_graph_id_);
  1101. auto &graph_type = GetGraphOrderType(final_graph_id_);
  1102. for (size_t i = 0; i < graph_order.size(); i++) {
  1103. if (graph_order[i] == to_graph_id) {
  1104. return;
  1105. }
  1106. }
  1107. // if graph is not in graph order,add it to graph order
  1108. SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false);
  1109. graph_order.push_back(to_graph_id);
  1110. graph_type.push_back(COMMON_GRAPH);
  1111. for (size_t i = 0; i < graph_order.size(); i++) {
  1112. MS_LOG(INFO) << "Index " << i << ",graph_id " << graph_order[i] << ",graph_type" << graph_type[i];
  1113. }
  1114. }
  1115. size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) {
  1116. auto output_num = AnfAlgo::GetOutputTensorNum(node);
  1117. if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
  1118. return input_index + output_num;
  1119. }
  1120. auto valid_inputs = graph->valid_inputs();
  1121. if (valid_inputs[input_index]) {
  1122. SetChildGraphParameter(node, graph->graph_id(), input_index);
  1123. } else {
  1124. MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
  1125. }
  1126. return ++input_index;
  1127. }
  1128. size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) {
  1129. MS_EXCEPTION_IF_NULL(value);
  1130. if (!value->isa<Tensor>()) {
  1131. MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
  1132. }
  1133. SetChildGraphParameter(value->cast<TensorPtr>(), graph->graph_id(), input_index);
  1134. return ++input_index;
  1135. }
  1136. size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) {
  1137. auto index = input_index;
  1138. for (auto &arg : vec_args) {
  1139. if (utils::isa<AnfNodePtr>(arg)) {
  1140. // arg is a anf node
  1141. auto node = utils::cast<AnfNodePtr>(arg);
  1142. index = SetChildGraphInput(graph, node, input_index);
  1143. } else if (utils::isa<ValuePtr>(arg)) {
  1144. // arg is a tensor
  1145. auto value = utils::cast<ValuePtr>(arg);
  1146. index = SetChildGraphInput(graph, value, input_index);
  1147. } else {
  1148. MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString();
  1149. }
  1150. }
  1151. return index;
  1152. }
  1153. void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
  1154. MS_LOG(INFO) << "Set input of graph " << g;
  1155. auto to_graph = GetGraph(g);
  1156. MS_EXCEPTION_IF_NULL(to_graph);
  1157. DumpGraphInputArgs(args);
  1158. UpdateGraphOrder(g);
  1159. auto &graph_inputs = to_graph->inputs();
  1160. auto real_args = GetRealArgs(to_graph, args);
  1161. size_t input_index = 0;
  1162. for (size_t i = 0; i < real_args.size(); i++) {
  1163. if (input_index >= graph_inputs.size()) {
  1164. MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
  1165. }
  1166. auto &real_arg = real_args[i];
  1167. if (utils::isa<AnfNodePtr>(real_arg)) {
  1168. // arg is a anf node
  1169. auto node = utils::cast<AnfNodePtr>(real_arg);
  1170. input_index = SetChildGraphInput(to_graph, node, input_index);
  1171. } else if (utils::isa<ValuePtr>(real_arg)) {
  1172. // arg is a tensor
  1173. auto value = utils::cast<ValuePtr>(real_arg);
  1174. input_index = SetChildGraphInput(to_graph, value, input_index);
  1175. } else if (utils::isa<VectorRef>(real_arg)) {
  1176. // arg is a VectorRef
  1177. auto vec_args = utils::cast<VectorRef>(real_arg);
  1178. input_index = SetChildGraphInput(to_graph, vec_args, input_index);
  1179. } else {
  1180. MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString();
  1181. }
  1182. }
  1183. MS_LOG(INFO) << "Finish!";
  1184. }
  1185. GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  1186. for (const auto &graph_item : graphs_) {
  1187. auto graph = graph_item.second;
  1188. MS_EXCEPTION_IF_NULL(graph);
  1189. // if front_anf is a parameter,the backend parameter may have two
  1190. if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
  1191. return graph_item.first;
  1192. }
  1193. }
  1194. MS_EXCEPTION_IF_NULL(front_anf);
  1195. MS_LOG(DEBUG) << "front_anf " << front_anf->DebugString() << " is not exist in any graph";
  1196. return kInvalidGraphId;
  1197. }
  1198. void AscendSession::MergeGraphExecOrder() {
  1199. MS_LOG(INFO) << "Start!";
  1200. // merge graph order
  1201. auto &graph_order = GetGraphOrder(final_graph_id_);
  1202. auto &graph_type = GetGraphOrderType(final_graph_id_);
  1203. auto final_graph = GetGraph(final_graph_id_);
  1204. MS_EXCEPTION_IF_NULL(final_graph);
  1205. if (graph_order.empty()) {
  1206. MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
  1207. return;
  1208. }
  1209. if (graph_order.size() > 1) {
  1210. auto context_ptr = MsContext::GetInstance();
  1211. MS_EXCEPTION_IF_NULL(context_ptr);
  1212. if (!context_ptr->enable_task_sink()) {
  1213. MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
  1214. }
  1215. }
  1216. // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
  1217. SetStreamDistinctionLabel(final_graph, graph_order[0], false);
  1218. std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
  1219. KernelGraphPtr last_graph = nullptr;
  1220. for (size_t i = 0; i < graph_order.size(); i++) {
  1221. auto graph_id = graph_order[i];
  1222. if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
  1223. continue;
  1224. }
  1225. auto child_graph = GetGraph(graph_id);
  1226. last_graph = child_graph;
  1227. MS_EXCEPTION_IF_NULL(child_graph);
  1228. auto exec_order = child_graph->execution_order();
  1229. MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
  1230. (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
  1231. [&](CNodePtr node) -> CNodePtr {
  1232. AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
  1233. return node;
  1234. });
  1235. // add all value nodes of child graphs to final graph
  1236. for (auto &value_node : child_graph->graph_value_nodes()) {
  1237. final_graph->AddValueNodeToGraph(value_node);
  1238. }
  1239. // copy ref map to final graph
  1240. auto child_ref_map = child_graph->GetRefMap();
  1241. for (auto &item : child_ref_map) {
  1242. if (final_graph->IsInRefOutputMap(item.first)) {
  1243. MS_LOG(EXCEPTION) << "The ref pair is already in final graph!";
  1244. }
  1245. final_graph->AddRefCorrespondPairs(item.first, item.second);
  1246. }
  1247. }
  1248. // set final_exec_order into final graph
  1249. MS_EXCEPTION_IF_NULL(final_graph);
  1250. DumpGraphExeOrder(final_exec_order);
  1251. final_graph->set_execution_order(final_exec_order);
  1252. }
  1253. void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  1254. MS_EXCEPTION_IF_NULL(from);
  1255. MS_EXCEPTION_IF_NULL(to);
  1256. if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
  1257. AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
  1258. return;
  1259. }
  1260. if (from.get() == to.get()) {
  1261. return;
  1262. }
  1263. MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to "
  1264. << to->DebugString();
  1265. auto graph = graphs_[graph_id];
  1266. MS_EXCEPTION_IF_NULL(graph);
  1267. // config inputs of assign node
  1268. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
  1269. // generate a new cnode
  1270. auto assign_node = graph->NewCNode(inputs);
  1271. MS_EXCEPTION_IF_NULL(assign_node);
  1272. assign_node->set_abstract(to->abstract());
  1273. // append the assign at the end of from graph
  1274. InsertDependToGraph(graph_id, assign_node);
  1275. }
  1276. void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
  1277. std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
  1278. std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
  1279. MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to ["
  1280. << AnfAlgo::GetGraphId(to.get()) << "]";
  1281. if (from_outputs.size() != to_outputs.size()) {
  1282. MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]";
  1283. MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size["
  1284. << to_outputs.size() << "]";
  1285. }
  1286. for (size_t i = 0; i < from_outputs.size(); i++) {
  1287. InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]);
  1288. }
  1289. }
  1290. void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) {
  1291. MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
  1292. auto from_graph = GetGraph(graph_id);
  1293. MS_EXCEPTION_IF_NULL(from_graph);
  1294. std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
  1295. auto active_node = from_graph->NewCNode(inputs);
  1296. MS_EXCEPTION_IF_NULL(active_node);
  1297. active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  1298. // set the active stream id into the attr of active node
  1299. std::vector<uint32_t> active_index_value = {};
  1300. active_index_value.push_back(actived_stream);
  1301. AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
  1302. // append the active node at the end of from graph
  1303. auto return_node = from_graph->get_return();
  1304. MS_EXCEPTION_IF_NULL(return_node);
  1305. InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
  1306. }
  1307. void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
  1308. AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node));
  1309. }
  1310. void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
  1311. const AnfNodePtr &second_node) {
  1312. AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node),
  1313. NOT_NULL(second_node));
  1314. }
  1315. size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
  1316. auto &graph_order = GetGraphOrder(final_graph);
  1317. for (size_t i = 0; i < graph_order.size(); i++) {
  1318. if (child_graph == graph_order[i]) {
  1319. return i;
  1320. }
  1321. }
  1322. return kInvalidIndex;
  1323. }
  1324. std::vector<GraphId> &AscendSession::GetGraphOrder(GraphId final_graph_id) {
  1325. auto graph_order_iter = graph_execute_orders_.find(final_graph_id);
  1326. if (graph_order_iter == graph_execute_orders_.end()) {
  1327. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph";
  1328. }
  1329. return graph_order_iter->second;
  1330. }
  1331. // get graph order type vector by graph id
  1332. std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id) {
  1333. auto graph_type_iter = graph_order_types_.find(final_graph_id);
  1334. if (graph_type_iter == graph_order_types_.end()) {
  1335. MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_";
  1336. }
  1337. return graph_type_iter->second;
  1338. }
  1339. void AscendSession::SyncInitialTenosrToDevice() {
  1340. for (auto &item : initial_tenosrs_) {
  1341. auto to_graph_id = item.first.first;
  1342. auto input_idx = item.first.second;
  1343. auto front_tensor = item.second;
  1344. auto to_graph = GetGraph(to_graph_id);
  1345. MS_EXCEPTION_IF_NULL(to_graph);
  1346. std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  1347. if (input_idx >= graph_inputs.size()) {
  1348. MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
  1349. }
  1350. auto backend_parameter = graph_inputs[input_idx];
  1351. // sync data from host to device
  1352. MS_EXCEPTION_IF_NULL(front_tensor);
  1353. size_t tensor_size = front_tensor->data().nbytes();
  1354. auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  1355. MS_EXCEPTION_IF_NULL(addr);
  1356. if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
  1357. front_tensor->data_type(), front_tensor->data_c(false))) {
  1358. MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  1359. }
  1360. }
  1361. }
  1362. static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list) {
  1363. // count the output of every anf node
  1364. std::set<AnfNodePtr> has_output_nodes;
  1365. for (auto &anf_node : list) {
  1366. for (auto &input : anf_node->inputs()) {
  1367. (void)has_output_nodes.insert(input);
  1368. }
  1369. }
  1370. auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
  1371. std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
  1372. int output_idx = 0;
  1373. for (auto &anf_node : list) {
  1374. if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
  1375. new_kernel_graph->set_return(anf_node);
  1376. }
  1377. if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
  1378. MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
  1379. make_tuple_inputs.push_back(anf_node);
  1380. }
  1381. }
  1382. if (new_kernel_graph->get_return() == nullptr) {
  1383. new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
  1384. }
  1385. }
  1386. std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
  1387. const std::vector<CNodePtr> &list) {
  1388. MS_EXCEPTION_IF_NULL(new_kernel_graph);
  1389. MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
  1390. MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
  1391. std::vector<AnfNodePtr> call_node_inputs;
  1392. std::vector<AnfNodePtr> new_graph_inputs;
  1393. // create new parameter from cnode
  1394. for (auto &anf_node : list) {
  1395. auto cnode = anf_node->cast<CNodePtr>();
  1396. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
  1397. auto input = cnode->inputs()[input_idx];
  1398. MS_EXCEPTION_IF_NULL(input);
  1399. AnfNodePtr new_parameter = nullptr;
  1400. // value node consider move to new graph
  1401. if (input->isa<ValueNode>()) {
  1402. cnode->set_input(input_idx, input);
  1403. continue;
  1404. } else if (input->isa<Parameter>()) {
  1405. // parameter reuse and should attention mulptiple use of one parameter
  1406. cnode->set_input(input_idx, input);
  1407. new_parameter = input;
  1408. } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
  1409. // if is cnode and not in current child graph
  1410. new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
  1411. cnode->set_input(input_idx, new_parameter);
  1412. } else {
  1413. // if is a cnode and in current graph
  1414. continue;
  1415. }
  1416. // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
  1417. // args
  1418. if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
  1419. new_graph_inputs.push_back(new_parameter);
  1420. call_node_inputs.push_back(input);
  1421. }
  1422. }
  1423. }
  1424. // set graph inputs of new graph
  1425. auto graph_inputs = new_kernel_graph->MutableInputs();
  1426. MS_EXCEPTION_IF_NULL(graph_inputs);
  1427. graph_inputs->clear();
  1428. std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
  1429. MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
  1430. ConstructSplitedGraphOutput(new_kernel_graph, list);
  1431. MS_LOG(INFO) << "end";
  1432. return call_node_inputs;
  1433. }
  1434. void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
  1435. std::set<KernelGraphPtr> memo;
  1436. // if root graph output is a call node ,the root graph is condition graph of 'if' sentence
  1437. auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
  1438. if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
  1439. SplitGraph(root_graph, {prim::kPrimReturn});
  1440. for (auto &child_graph : root_graph->child_graph_order()) {
  1441. RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
  1442. }
  1443. } else {
  1444. RecurseSplitGraph(root_graph, NOT_NULL(&memo));
  1445. }
  1446. memo.clear();
  1447. // replace the real input if the real input is a call
  1448. RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
  1449. }
  1450. AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
  1451. const std::vector<CNodePtr> &child_graph_list) {
  1452. // if child graph list only has a call ,then return the exist call
  1453. if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
  1454. return child_graph_list[0];
  1455. }
  1456. // create new child graph
  1457. auto child_graph = NewKernelGraph();
  1458. MS_EXCEPTION_IF_NULL(child_graph);
  1459. // create new value node to bind child graph
  1460. auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
  1461. std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
  1462. graph_value_node};
  1463. // set the graph id of all node of child graph
  1464. for (auto &child_graph_node : child_graph_list) {
  1465. AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
  1466. }
  1467. auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
  1468. std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
  1469. auto new_call = graph->NewCNode(new_call_input);
  1470. AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
  1471. return new_call;
  1472. }
  1473. void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
  1474. MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
  1475. auto apply_list = GetCNodes(TopoSort(graph->get_return()));
  1476. // update the root graph child graph order
  1477. AscendControlParser::UpdateChildGraphOrder(graph);
  1478. // get child list from current graph
  1479. std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
  1480. if (child_graph_lists.size() > 1) {
  1481. std::list<AnfNodePtr> depend_input = {};
  1482. for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
  1483. auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]);
  1484. MS_EXCEPTION_IF_NULL(call_node);
  1485. // if call node is the last call of true graph,no need create child graph after that
  1486. auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
  1487. depend_input.push_front(call_node);
  1488. if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) {
  1489. break;
  1490. }
  1491. }
  1492. depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))));
  1493. auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end()));
  1494. auto new_return_primitive =
  1495. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
  1496. graph->set_return(graph->NewCNode({new_return_primitive, depend}));
  1497. AnfNodePtr pre_call_node = nullptr;
  1498. AnfNodePtr cur_call_node = nullptr;
  1499. auto iter = depend_input.begin();
  1500. for (++iter; iter != depend_input.end(); ++iter) {
  1501. pre_call_node = cur_call_node;
  1502. cur_call_node = *iter;
  1503. if (pre_call_node != nullptr && cur_call_node != nullptr) {
  1504. AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node));
  1505. }
  1506. }
  1507. }
  1508. AscendControlParser::UpdateChildGraphOrder(graph);
  1509. UpdateRealInput(graph);
  1510. MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
  1511. // recurse to split child graph
  1512. }
  1513. void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
  1514. memo->insert(graph.get());
  1515. SplitGraph(graph, {prim::kPrimCall});
  1516. for (auto &child_graph : graph->child_graph_order()) {
  1517. if (memo->find(child_graph) == memo->end()) {
  1518. RecurseSplitGraph(NOT_NULL(child_graph), memo);
  1519. }
  1520. }
  1521. }
  1522. void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
  1523. void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
  1524. AscendControlParser::ExecutorValidate(graph);
  1525. }
  1526. void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
  1527. memo->insert(graph.get());
  1528. CompileChildGraph(graph);
  1529. for (auto child_graph : graph->child_graph_order()) {
  1530. if (memo->find(child_graph) != memo->end()) {
  1531. continue;
  1532. }
  1533. RecurseCompileGraph(NOT_NULL(child_graph), memo);
  1534. }
  1535. }
  1536. } // namespace session
  1537. } // namespace mindspore