You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_adjust.cc 56 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/kernel_adjust.h"
  17. #include <map>
  18. #include <algorithm>
  19. #include <string>
  20. #include <vector>
  21. #include <utility>
  22. #include "backend/common/session/anf_runtime_algorithm.h"
  23. #include "include/common/utils/anfalgo.h"
  24. #include "utils/ms_context.h"
  25. #include "runtime/device/ms_device_shape_transfer.h"
  26. #include "include/common/utils/config_manager.h"
  27. #include "utils/ms_utils.h"
  28. #include "kernel/kernel_build_info.h"
  29. #include "include/common/utils/utils.h"
  30. #include "plugin/device/ascend/hal/device/profiling/profiling_manager.h"
  31. #include "runtime/base.h"
  32. #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
  33. #include "utils/shape_utils.h"
  34. #ifndef ENABLE_SECURITY
  35. #include "debug/data_dump/dump_json_parser.h"
  36. #endif
  37. namespace {
  38. constexpr auto kGradients = "Gradients";
  39. constexpr auto kSpecifyParameter = "accu_status";
  40. size_t kNPUShape = 8;
  41. constexpr size_t kLastHandleDiff = 2;
  42. } // namespace
  43. namespace mindspore {
  44. namespace device {
  45. #ifndef ENABLE_SECURITY
  46. using device::ascend::ProfilingUtils;
  47. #endif
  48. void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  49. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  50. const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order();
  51. std::vector<CNodePtr> getnext_list;
  52. std::vector<CNodePtr> other_list;
  53. for (const auto &cnode : origin_cnode_list) {
  54. if (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
  55. getnext_list.emplace_back(cnode);
  56. } else {
  57. other_list.emplace_back(cnode);
  58. }
  59. }
  60. std::vector<CNodePtr> new_order_list;
  61. new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end());
  62. new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end());
  63. kernel_graph_ptr->set_execution_order(new_order_list);
  64. }
  65. bool KernelAdjust::NeedLoopSink() {
  66. auto context_ptr = MsContext::GetInstance();
  67. MS_EXCEPTION_IF_NULL(context_ptr);
  68. return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) &&
  69. context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
  70. }
  71. CNodePtr CreateEventApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
  72. std::vector<AnfNodePtr> input_list) {
  73. MS_EXCEPTION_IF_NULL(graph_ptr);
  74. CNodePtr event_node_ptr = graph_ptr->NewCNode(input_list);
  75. MS_EXCEPTION_IF_NULL(event_node_ptr);
  76. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  77. selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
  78. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), event_node_ptr.get());
  79. common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), event_node_ptr);
  80. auto abstract_none = std::make_shared<abstract::AbstractNone>();
  81. event_node_ptr->set_abstract(abstract_none);
  82. return event_node_ptr;
  83. }
  84. CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
  85. uint32_t event_id) {
  86. MS_EXCEPTION_IF_NULL(graph_ptr);
  87. auto send_op = std::make_shared<Primitive>(kSendOpName);
  88. MS_EXCEPTION_IF_NULL(send_op);
  89. auto send_apply = std::make_shared<ValueNode>(send_op);
  90. MS_EXCEPTION_IF_NULL(send_apply);
  91. return CreateEventApplyKernel(graph_ptr, event_id, {send_apply});
  92. }
  93. CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
  94. uint32_t event_id) {
  95. MS_EXCEPTION_IF_NULL(graph_ptr);
  96. auto recv_op = std::make_shared<Primitive>(kRecvOpName);
  97. MS_EXCEPTION_IF_NULL(recv_op);
  98. auto recv_apply = std::make_shared<ValueNode>(recv_op);
  99. MS_EXCEPTION_IF_NULL(recv_apply);
  100. return CreateEventApplyKernel(graph_ptr, event_id, {recv_apply});
  101. }
  102. bool KernelAdjust::ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  103. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  104. const std::vector<CNodePtr> &cnode_list = kernel_graph_ptr->execution_order();
  105. return std::any_of(cnode_list.begin(), cnode_list.end(),
  106. [](const CNodePtr &cnode) { return common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName; });
  107. }
  108. bool KernelAdjust::ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  109. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  110. const auto &exe_orders = kernel_graph_ptr->execution_order();
  111. return std::any_of(exe_orders.begin(), exe_orders.end(), [&kernel_graph_ptr](const CNodePtr &node) {
  112. return AnfAlgo::IsIndependentNode(node) && AnfAlgo::GetGraphId(node.get()) == kernel_graph_ptr->graph_id();
  113. });
  114. }
  115. void KernelAdjust::InsertIndepentParallel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  116. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  117. std::vector<CNodePtr> *exec_order) {
  118. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  119. MS_EXCEPTION_IF_NULL(exec_order);
  120. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  121. CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch);
  122. MS_EXCEPTION_IF_NULL(independent_switch_app);
  123. uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream();
  124. AnfAlgo::SetStreamId(independent_switch_stream_id, independent_switch_app.get());
  125. common::AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), independent_switch_app);
  126. common::AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kIndependentStreamSwitch),
  127. independent_switch_app);
  128. (*exec_order).push_back(independent_switch_app);
  129. MS_LOG(INFO) << "Independent op loop insert Stream Switch " << independent_switch_app->fullname_with_scope();
  130. }
  131. void KernelAdjust::InsertFpBpLoopStreamSwitch(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  132. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  133. std::vector<CNodePtr> *exec_order, uint32_t *fpbp_stream_id,
  134. uint32_t *fpbp_switch_stream_id) {
  135. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  136. MS_EXCEPTION_IF_NULL(exec_order);
  137. MS_EXCEPTION_IF_NULL(fpbp_stream_id);
  138. MS_EXCEPTION_IF_NULL(fpbp_switch_stream_id);
  139. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  140. *fpbp_switch_stream_id = resource_manager.ApplyNewStream();
  141. *fpbp_stream_id = resource_manager.ApplyNewStream();
  142. CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch);
  143. MS_EXCEPTION_IF_NULL(fpbp_switch_app);
  144. AnfAlgo::SetStreamId(*fpbp_switch_stream_id, fpbp_switch_app.get());
  145. common::AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
  146. // update fpbp loop stream switch true_branch_stream attr
  147. common::AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*fpbp_stream_id), fpbp_switch_app);
  148. common::AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kFpBpStreamSwitch), fpbp_switch_app);
  149. (*exec_order).push_back(fpbp_switch_app);
  150. MS_LOG(INFO) << "FpBp loop insert Stream Switch " << fpbp_switch_app->fullname_with_scope();
  151. }
  152. void KernelAdjust::CopyMemcpyList(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  153. const std::vector<CNodePtr> &orders, size_t order_index,
  154. std::vector<CNodePtr> *memcpy_list, std::vector<CNodePtr> *other_list) {
  155. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  156. MS_EXCEPTION_IF_NULL(memcpy_list);
  157. MS_EXCEPTION_IF_NULL(other_list);
  158. CNodePtr cur_cnode = nullptr;
  159. for (size_t idx = order_index + 1; idx < orders.size(); idx++) {
  160. cur_cnode = orders[idx];
  161. if (common::AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
  162. auto pre_node = orders[idx - 1];
  163. auto pre_kernel_name = common::AnfAlgo::GetCNodeName(pre_node);
  164. if (pre_kernel_name == kAtomicAddrCleanOpName) {
  165. (*other_list).pop_back();
  166. (*memcpy_list).push_back(pre_node);
  167. }
  168. (*memcpy_list).emplace_back(cur_cnode);
  169. } else {
  170. (*other_list).emplace_back(cur_cnode);
  171. }
  172. }
  173. }
  174. void KernelAdjust::InsertEosDoneRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  175. std::vector<CNodePtr> *exec_order, uint32_t eos_done_event_id,
  176. uint32_t fpbp_stream_id) {
  177. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  178. MS_EXCEPTION_IF_NULL(exec_order);
  179. CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id);
  180. AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get());
  181. (*exec_order).push_back(eos_done_recv);
  182. MS_LOG(INFO) << "FpBp loop insert EoS done Recv " << eos_done_recv->fullname_with_scope();
  183. }
  184. void KernelAdjust::InsertGetNextLoopStreamActive(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  185. std::vector<CNodePtr> *exec_order,
  186. const std::vector<uint32_t> &getnext_active_streams) {
  187. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  188. MS_EXCEPTION_IF_NULL(exec_order);
  189. CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
  190. MS_EXCEPTION_IF_NULL(getnext_active_app);
  191. common::AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
  192. getnext_active_app);
  193. (*exec_order).push_back(getnext_active_app);
  194. MS_LOG(INFO) << "FpBp loop insert GetNext loop Stream Active " << getnext_active_app->fullname_with_scope();
  195. }
  196. void KernelAdjust::InsertFpBpStartRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  197. std::vector<CNodePtr> *exec_order, uint32_t fpbp_start_event_id,
  198. uint32_t fpbp_stream_id) {
  199. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  200. MS_EXCEPTION_IF_NULL(exec_order);
  201. CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
  202. AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get());
  203. (*exec_order).push_back(fpbp_start_recv);
  204. MS_LOG(INFO) << "FpBp loop insert FpBp start Recv " << fpbp_start_recv->fullname_with_scope();
  205. }
  206. void KernelAdjust::InsertNextLoopAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  207. std::vector<CNodePtr> *exec_order,
  208. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  209. uint32_t fpbp_stream_id) {
  210. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  211. MS_EXCEPTION_IF_NULL(exec_order);
  212. CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false);
  213. MS_EXCEPTION_IF_NULL(assign_add_one);
  214. AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
  215. (*exec_order).push_back(assign_add_one);
  216. MS_LOG(INFO) << "FpBp loop insert next loop AssignAdd " << assign_add_one->fullname_with_scope();
  217. }
  218. void KernelAdjust::InsertCurrentLoopAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  219. std::vector<CNodePtr> *exec_order,
  220. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
  221. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  222. MS_EXCEPTION_IF_NULL(exec_order);
  223. CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true);
  224. MS_EXCEPTION_IF_NULL(cur_assign_add);
  225. common::AnfAlgo::SetNodeAttr(kAttrFpBpEnd, MakeValue<bool>(true), cur_assign_add);
  226. (*exec_order).push_back(cur_assign_add);
  227. MS_LOG(INFO) << "FpBp loop insert current loop AssignAdd " << cur_assign_add->fullname_with_scope();
  228. }
  229. void KernelAdjust::InsertFpBpAndEosLoopStreamActive(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  230. std::vector<CNodePtr> *exec_order,
  231. const std::vector<uint32_t> &fpbp_active_streams) {
  232. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  233. MS_EXCEPTION_IF_NULL(exec_order);
  234. CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
  235. MS_EXCEPTION_IF_NULL(fpbp_active_app);
  236. common::AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams),
  237. fpbp_active_app);
  238. (*exec_order).push_back(fpbp_active_app);
  239. MS_LOG(INFO) << "FpBp loop insert FpBp loop and Eos loop Stream Active " << fpbp_active_app->fullname_with_scope();
  240. }
  241. void KernelAdjust::InsertGetNextLoopStreamSwitch(
  242. const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, std::vector<CNodePtr> *exec_order,
  243. uint32_t *getnext_switch_stream_id, uint32_t *getnext_stream_id,
  244. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
  245. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  246. MS_EXCEPTION_IF_NULL(exec_order);
  247. MS_EXCEPTION_IF_NULL(getnext_switch_stream_id);
  248. MS_EXCEPTION_IF_NULL(getnext_stream_id);
  249. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  250. *getnext_switch_stream_id = resource_manager.ApplyNewStream();
  251. *getnext_stream_id = resource_manager.ApplyNewStream();
  252. CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
  253. MS_EXCEPTION_IF_NULL(getnext_switch_app);
  254. AnfAlgo::SetStreamId(*getnext_switch_stream_id, getnext_switch_app.get());
  255. // update getnext loop stream switch true_branch_stream attr
  256. common::AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app);
  257. common::AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*getnext_stream_id), getnext_switch_app);
  258. common::AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app);
  259. (*exec_order).push_back(getnext_switch_app);
  260. MS_LOG(INFO) << "GetNext loop insert Stream Switch " << getnext_switch_app->fullname_with_scope();
  261. }
  262. void KernelAdjust::SetBeforeGetNextStreamID(std::vector<CNodePtr> *exec_order, const std::vector<CNodePtr> &orders,
  263. size_t *order_index, CNodePtr getnext_cnode, uint32_t getnext_stream_id) {
  264. MS_EXCEPTION_IF_NULL(exec_order);
  265. MS_EXCEPTION_IF_NULL(order_index);
  266. for (; *order_index < orders.size(); (*order_index)++) {
  267. auto node = orders[*order_index];
  268. (*exec_order).push_back(node);
  269. AnfAlgo::SetStreamId(getnext_stream_id, (*exec_order)[(*exec_order).size() - 1].get());
  270. if (common::AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
  271. getnext_cnode = node;
  272. break;
  273. }
  274. }
  275. }
  276. void KernelAdjust::InsertGetNextLoopFpBpStartSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  277. std::vector<CNodePtr> *exec_order, uint32_t *fpbp_start_event_id,
  278. uint32_t getnext_stream_id) {
  279. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  280. MS_EXCEPTION_IF_NULL(exec_order);
  281. MS_EXCEPTION_IF_NULL(fpbp_start_event_id);
  282. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  283. *fpbp_start_event_id = resource_manager.ApplyNewEvent();
  284. CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, *fpbp_start_event_id);
  285. AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
  286. (*exec_order).push_back(fpbp_start_send);
  287. MS_LOG(INFO) << "GetNext loop insert FpBp start Send " << fpbp_start_send->fullname_with_scope();
  288. }
  289. void KernelAdjust::InsertGetNextLoopEosStartSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  290. std::vector<CNodePtr> *exec_order, uint32_t *eos_start_event_id,
  291. uint32_t getnext_stream_id) {
  292. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  293. MS_EXCEPTION_IF_NULL(exec_order);
  294. MS_EXCEPTION_IF_NULL(eos_start_event_id);
  295. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  296. *eos_start_event_id = resource_manager.ApplyNewEvent();
  297. CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_start_event_id);
  298. AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
  299. (*exec_order).push_back(eos_start_send);
  300. MS_LOG(INFO) << "GetNext loop insert EoS start Send " << eos_start_send->fullname_with_scope();
  301. }
  302. void KernelAdjust::InsertEosStreamSwitch(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  303. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  304. std::vector<CNodePtr> *exec_order, uint32_t *eos_switch_stream_id,
  305. uint32_t *eos_stream_id) {
  306. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  307. MS_EXCEPTION_IF_NULL(exec_order);
  308. MS_EXCEPTION_IF_NULL(eos_switch_stream_id);
  309. MS_EXCEPTION_IF_NULL(eos_stream_id);
  310. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  311. *eos_switch_stream_id = resource_manager.ApplyNewStream();
  312. *eos_stream_id = resource_manager.ApplyNewStream();
  313. CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch);
  314. MS_EXCEPTION_IF_NULL(eos_switch_app);
  315. AnfAlgo::SetStreamId(*eos_switch_stream_id, eos_switch_app.get());
  316. common::AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), eos_switch_app);
  317. // update eos loop stream switch true_branch_stream attr
  318. common::AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*eos_stream_id), eos_switch_app);
  319. common::AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kEosStreamSwitch), eos_switch_app);
  320. (*exec_order).push_back(eos_switch_app);
  321. MS_LOG(INFO) << "EoS loop insert Stream Switch " << eos_switch_app->fullname_with_scope();
  322. }
  323. void KernelAdjust::InsertGetNextLoopEosStartRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  324. std::vector<CNodePtr> *exec_order, uint32_t eos_start_event_id,
  325. uint32_t eos_stream_id) {
  326. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  327. MS_EXCEPTION_IF_NULL(exec_order);
  328. CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id);
  329. AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get());
  330. (*exec_order).push_back(eos_start_recv);
  331. MS_LOG(INFO) << "EoS loop insert EoS Recv " << eos_start_recv->fullname_with_scope();
  332. }
  333. void KernelAdjust::InsertEosOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  334. std::vector<CNodePtr> *exec_order, const CNodePtr &getnext_cnode,
  335. uint32_t eos_stream_id) {
  336. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  337. MS_EXCEPTION_IF_NULL(exec_order);
  338. MS_EXCEPTION_IF_NULL(getnext_cnode);
  339. CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
  340. MS_EXCEPTION_IF_NULL(end_of_sequence_op);
  341. AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get());
  342. (*exec_order).push_back(end_of_sequence_op);
  343. MS_LOG(INFO) << "EoS loop insert Eos Op " << end_of_sequence_op->fullname_with_scope();
  344. }
  345. void KernelAdjust::InsertEosDoneSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  346. std::vector<CNodePtr> *exec_order, uint32_t *eos_done_event_id,
  347. uint32_t eos_stream_id) {
  348. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  349. MS_EXCEPTION_IF_NULL(exec_order);
  350. MS_EXCEPTION_IF_NULL(eos_done_event_id);
  351. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  352. *eos_done_event_id = resource_manager.ApplyNewEvent();
  353. CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_done_event_id);
  354. AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get());
  355. (*exec_order).push_back(eos_done_send);
  356. MS_LOG(INFO) << "EoS loop insert EoS done Send " << eos_done_send->fullname_with_scope();
  357. }
  358. void KernelAdjust::ProcessLoopSink(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  359. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  360. device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
  361. resource_manager.ResetResource();
  362. if (!NeedLoopSink()) {
  363. return;
  364. }
  365. if (kernel_graph_ptr->is_dynamic_shape()) {
  366. MS_LOG(INFO) << "KernelGraph:" << kernel_graph_ptr->graph_id() << " is dynamic shape, skip ProcessLoopSink";
  367. return;
  368. }
  369. bool exist_getnext = ExistGetNext(kernel_graph_ptr);
  370. bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX && exist_getnext;
  371. MS_LOG(INFO) << "GetNext exist:" << exist_getnext << " End of Sequence mode:" << eos_mode
  372. << " iter num:" << ConfigManager::GetInstance().iter_num();
  373. if (exist_getnext) {
  374. ReorderGetNext(kernel_graph_ptr);
  375. }
  376. auto switch_loop_input = kernel_graph_ptr->device_loop_control_params();
  377. const std::vector<CNodePtr> &orders = kernel_graph_ptr->execution_order();
  378. if (orders.empty()) {
  379. MS_LOG(EXCEPTION) << "graph " << kernel_graph_ptr->graph_id() << " execution order is empty";
  380. }
  381. std::vector<CNodePtr> exec_order;
  382. CNodePtr getnext_cnode;
  383. uint32_t getnext_switch_stream_id = UINT32_MAX;
  384. uint32_t fpbp_start_event_id = UINT32_MAX;
  385. uint32_t eos_start_event_id = UINT32_MAX;
  386. uint32_t getnext_stream_id = UINT32_MAX;
  387. size_t order_index = 0;
  388. if (exist_getnext) {
  389. InsertGetNextLoopStreamSwitch(kernel_graph_ptr, &exec_order, &getnext_switch_stream_id, &getnext_stream_id,
  390. switch_loop_input);
  391. SetBeforeGetNextStreamID(&exec_order, orders, &order_index, getnext_cnode, getnext_stream_id);
  392. InsertGetNextLoopFpBpStartSend(kernel_graph_ptr, &exec_order, &fpbp_start_event_id, getnext_stream_id);
  393. if (eos_mode) {
  394. InsertGetNextLoopEosStartSend(kernel_graph_ptr, &exec_order, &eos_start_event_id, getnext_stream_id);
  395. }
  396. }
  397. uint32_t eos_switch_stream_id = UINT32_MAX;
  398. uint32_t eos_stream_id = UINT32_MAX;
  399. uint32_t eos_done_event_id = UINT32_MAX;
  400. std::vector<uint32_t> fpbp_active_streams;
  401. if (eos_mode) {
  402. InsertEosStreamSwitch(kernel_graph_ptr, switch_loop_input, &exec_order, &eos_switch_stream_id, &eos_stream_id);
  403. InsertGetNextLoopEosStartRecv(kernel_graph_ptr, &exec_order, eos_start_event_id, eos_stream_id);
  404. InsertEosOp(kernel_graph_ptr, &exec_order, getnext_cnode, eos_stream_id);
  405. InsertEosDoneSend(kernel_graph_ptr, &exec_order, &eos_done_event_id, eos_stream_id);
  406. fpbp_active_streams.push_back(eos_switch_stream_id);
  407. }
  408. bool exist_independent = ExistIndependent(kernel_graph_ptr);
  409. if (exist_independent) {
  410. InsertIndepentParallel(kernel_graph_ptr, switch_loop_input, &exec_order);
  411. }
  412. uint32_t fpbp_stream_id = UINT32_MAX;
  413. uint32_t fpbp_switch_stream_id = UINT32_MAX;
  414. InsertFpBpLoopStreamSwitch(kernel_graph_ptr, switch_loop_input, &exec_order, &fpbp_stream_id, &fpbp_switch_stream_id);
  415. if (exist_getnext) {
  416. InsertFpBpStartRecv(kernel_graph_ptr, &exec_order, fpbp_start_event_id, fpbp_stream_id);
  417. }
  418. InsertNextLoopAssignAdd(kernel_graph_ptr, &exec_order, switch_loop_input, fpbp_stream_id);
  419. std::vector<CNodePtr> memcpy_list;
  420. std::vector<CNodePtr> other_list;
  421. if (exist_getnext) {
  422. CopyMemcpyList(kernel_graph_ptr, orders, order_index, &memcpy_list, &other_list);
  423. (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
  424. } else {
  425. other_list = orders;
  426. }
  427. if (eos_mode) {
  428. InsertEosDoneRecv(kernel_graph_ptr, &exec_order, eos_done_event_id, fpbp_stream_id);
  429. }
  430. std::vector<uint32_t> getnext_active_streams;
  431. if (exist_getnext) {
  432. // small loop active
  433. getnext_active_streams.push_back(getnext_switch_stream_id);
  434. InsertGetNextLoopStreamActive(kernel_graph_ptr, &exec_order, getnext_active_streams);
  435. }
  436. (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
  437. InsertCurrentLoopAssignAdd(kernel_graph_ptr, &exec_order, switch_loop_input);
  438. // big loop active
  439. fpbp_active_streams.push_back(fpbp_switch_stream_id);
  440. InsertFpBpAndEosLoopStreamActive(kernel_graph_ptr, &exec_order, fpbp_active_streams);
  441. kernel_graph_ptr->set_execution_order(exec_order);
  442. }
  443. kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder(
  444. const std::vector<std::string> &formats, const std::vector<TypeId> &type_ids) {
  445. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  446. selected_kernel_builder.SetInputsFormat(formats);
  447. selected_kernel_builder.SetInputsDeviceType(type_ids);
  448. selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  449. selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
  450. selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
  451. return selected_kernel_builder;
  452. }
  453. CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  454. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  455. StreamSwitchKind kind) {
  456. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
  457. {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
  458. auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
  459. auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
  460. std::vector<AnfNodePtr> inputs;
  461. inputs.push_back(NewValueNode(stream_switch));
  462. if (kind == kFpBpStreamSwitch || kind == kEosStreamSwitch) {
  463. inputs.push_back(switch_loop_input.at(kNextLoopCountName));
  464. } else if (kind == kGetNextStreamSwitch || kind == kIndependentStreamSwitch) {
  465. inputs.push_back(switch_loop_input.at(kNextLoopCountName));
  466. } else {
  467. MS_LOG(ERROR) << "unknown stream switch kind: " << kind;
  468. }
  469. inputs.push_back(switch_loop_input.at(kConstLoopNumInEpochName));
  470. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  471. CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs);
  472. MS_EXCEPTION_IF_NULL(stream_switch_app);
  473. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get());
  474. stream_switch_app->set_abstract(typeNone_abstract);
  475. // set attr: cond_ RT_LESS
  476. int condition = static_cast<int>(RT_LESS_OR_EQUAL);
  477. ValuePtr cond = MakeValue(condition);
  478. common::AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
  479. // set attr:data_type
  480. int data_type = static_cast<int>(RT_SWITCH_INT64);
  481. ValuePtr dt = MakeValue(data_type);
  482. common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
  483. // set distinction label and graph id
  484. return stream_switch_app;
  485. }
  486. CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  487. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
  488. {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
  489. abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>();
  490. auto stream_active_others = std::make_shared<Primitive>(kStreamActiveOpName);
  491. std::vector<AnfNodePtr> inputs;
  492. inputs.push_back(NewValueNode(stream_active_others));
  493. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  494. CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs);
  495. MS_EXCEPTION_IF_NULL(stream_active_others_app);
  496. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get());
  497. stream_active_others_app->set_abstract(typeNone_abstract);
  498. return stream_active_others_app;
  499. }
  500. CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  501. const CNodePtr &node, size_t output_idx) {
  502. auto idx = NewValueNode(SizeToLong(output_idx));
  503. MS_EXCEPTION_IF_NULL(idx);
  504. auto imm = std::make_shared<Int64Imm>(SizeToInt(output_idx));
  505. auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  506. idx->set_abstract(abstract_scalar);
  507. CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  508. MS_EXCEPTION_IF_NULL(tuple_getitem);
  509. tuple_getitem->set_scope(node->scope());
  510. std::vector<size_t> origin_shape = common::AnfAlgo::GetOutputInferShape(node, output_idx);
  511. TypeId origin_type = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
  512. common::AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  513. return tuple_getitem;
  514. }
  515. CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  516. const CNodePtr &getnext_cnode) {
  517. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  518. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  519. selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
  520. selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8});
  521. selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  522. selected_kernel_builder.SetProcessor(kernel::Processor::AICPU);
  523. selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL);
  524. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  525. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8});
  526. // EndOfSequence
  527. auto end_of_sequence = std::make_shared<Primitive>(kEndOfSequence);
  528. std::vector<AnfNodePtr> inputs;
  529. inputs.push_back(NewValueNode(end_of_sequence));
  530. // GetNext output 0 is EndOfSequence's input
  531. auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0);
  532. inputs.push_back(tuple_get_item);
  533. CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs);
  534. MS_EXCEPTION_IF_NULL(end_of_sequence_node);
  535. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get());
  536. std::vector<std::string> input_names = {"x"};
  537. ValuePtr input_names_v = MakeValue(input_names);
  538. common::AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node);
  539. std::vector<std::string> output_names = {"y"};
  540. ValuePtr output_names_v = MakeValue(output_names);
  541. common::AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node);
  542. end_of_sequence_node->set_abstract(tuple_get_item->abstract());
  543. return end_of_sequence_node;
  544. }
  545. CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  546. const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
  547. bool cur_loop) {
  548. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  549. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
  550. {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
  551. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  552. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
  553. // AssignAdd
  554. auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
  555. std::vector<AnfNodePtr> inputs;
  556. inputs.push_back(NewValueNode(assign_add));
  557. if (cur_loop) {
  558. inputs.push_back(switch_loop_input.at(kCurLoopCountName));
  559. } else {
  560. inputs.push_back(switch_loop_input.at(kNextLoopCountName));
  561. }
  562. inputs.push_back(switch_loop_input.at(kConstOneName));
  563. CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
  564. MS_EXCEPTION_IF_NULL(assign_add_one);
  565. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get());
  566. std::vector<std::string> input_names = {"ref", "value"};
  567. std::vector<std::string> output_names = {"output"};
  568. ValuePtr input_names_v = MakeValue(input_names);
  569. ValuePtr output_names_v = MakeValue(output_names);
  570. common::AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
  571. common::AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
  572. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  573. MS_EXCEPTION_IF_NULL(switch_loop_input.at(kCurLoopCountName));
  574. assign_add_one->set_abstract(switch_loop_input.at(kCurLoopCountName)->abstract());
  575. // add AssignAdd op to kernel ref node map
  576. session::AnfWithOutIndex final_pair = std::make_pair(assign_add_one, 0);
  577. session::KernelWithIndex kernel_with_index =
  578. common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(assign_add_one, 0), 0);
  579. kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
  580. return assign_add_one;
  581. }
  582. #ifndef ENABLE_SECURITY
  583. void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) {
  584. if (!ascend::ProfilingManager::GetInstance().IsProfilingInitialized()) {
  585. MS_LOG(INFO) << "No need to profiling";
  586. return;
  587. }
  588. ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GenerateProfilingTrace(*kernel_graph_ptr);
  589. if (!profiling_trace_info.IsValid()) {
  590. MS_LOG(INFO) << "[profiling] no profiling node found!";
  591. return;
  592. }
  593. InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr);
  594. }
  595. void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
  596. NotNull<session::KernelGraph *> kernel_graph_ptr) {
  597. MS_LOG(INFO) << "[profiling] Insert profiling kernel start";
  598. if (!profiling_trace_info.IsValid()) {
  599. MS_LOG(WARNING) << "Profiling trace point not found";
  600. return;
  601. }
  602. std::vector<CNodePtr> new_cnode_list;
  603. std::vector<CNodePtr> cnode_ptr_list = kernel_graph_ptr->execution_order();
  604. if (cnode_ptr_list.empty()) {
  605. MS_LOG(ERROR) << "No CNode in graph " << kernel_graph_ptr->graph_id();
  606. return;
  607. }
  608. for (const auto &cnode_ptr : cnode_ptr_list) {
  609. ProfilingUtils::InsertProfilingTraceFp(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
  610. NOT_NULL(&new_cnode_list));
  611. new_cnode_list.emplace_back(cnode_ptr);
  612. ProfilingUtils::InsertProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
  613. NOT_NULL(&new_cnode_list));
  614. ProfilingUtils::InsertProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
  615. NOT_NULL(&new_cnode_list));
  616. ProfilingUtils::InsertProfilingTraceIterEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
  617. NOT_NULL(&new_cnode_list));
  618. }
  619. kernel_graph_ptr->set_execution_order(new_cnode_list);
  620. }
  621. #endif
  622. CNodePtr KernelAdjust::CreateNPUGetFloatStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  623. const CNodePtr &npu_alloc_cnode) {
  624. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  625. MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
  626. auto npu_get_primitive = std::make_shared<Primitive>(kNPUGetFloatStatusOpName);
  627. std::vector<AnfNodePtr> npu_get_inputs = {NewValueNode(npu_get_primitive), npu_alloc_cnode};
  628. auto npu_get_cnode = kernel_graph_ptr->NewCNode(npu_get_inputs);
  629. MS_EXCEPTION_IF_NULL(npu_get_cnode);
  630. npu_alloc_cnode->set_scope(kDefaultScope);
  631. npu_get_cnode->set_abstract(npu_alloc_cnode->abstract());
  632. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  633. selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
  634. selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
  635. selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  636. selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
  637. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  638. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  639. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
  640. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_get_cnode.get());
  641. return npu_get_cnode;
  642. }
  643. CNodePtr KernelAdjust::CreateNPUClearStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  644. const CNodePtr &npu_alloc_cnode) {
  645. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  646. MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
  647. auto npu_clear_primitive = std::make_shared<Primitive>(kNPUClearFloatStatusOpName);
  648. std::vector<AnfNodePtr> npu_clear_inputs = {NewValueNode(npu_clear_primitive), npu_alloc_cnode};
  649. auto npu_clear_cnode = kernel_graph_ptr->NewCNode(npu_clear_inputs);
  650. MS_EXCEPTION_IF_NULL(npu_clear_cnode);
  651. npu_alloc_cnode->set_scope(kDefaultScope);
  652. npu_clear_cnode->set_abstract(npu_alloc_cnode->abstract());
  653. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  654. selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
  655. selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
  656. selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  657. selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
  658. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  659. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  660. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
  661. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_clear_cnode.get());
  662. return npu_clear_cnode;
  663. }
  664. CNodePtr KernelAdjust::CreateNPUAllocStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  665. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  666. // create npu_alloc_cnode
  667. auto npu_alloc_primitive = std::make_shared<Primitive>(kNPUAllocFloatStatusOpName);
  668. std::vector<AnfNodePtr> npu_alloc_inputs = {NewValueNode(npu_alloc_primitive)};
  669. auto npu_alloc_cnode = kernel_graph_ptr->NewCNode(npu_alloc_inputs);
  670. MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
  671. npu_alloc_cnode->set_scope(kDefaultScope);
  672. std::vector<size_t> npu_output_shape = {kNPUShape};
  673. common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {npu_output_shape}, npu_alloc_cnode.get());
  674. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  675. selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  676. selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
  677. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  678. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  679. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
  680. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_alloc_cnode.get());
  681. return npu_alloc_cnode;
  682. }
  683. CNodePtr KernelAdjust::CreateAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  684. const CNodePtr &npu_alloc_cnode, const AnfNodePtr &specify_para) {
  685. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  686. MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
  687. MS_EXCEPTION_IF_NULL(specify_para);
  688. auto assign_add_primitive = std::make_shared<Primitive>(kAssignAddOpName);
  689. std::vector<AnfNodePtr> assign_add_inputs = {NewValueNode(assign_add_primitive), specify_para, npu_alloc_cnode};
  690. auto assign_add_cnode = kernel_graph_ptr->NewCNode(assign_add_inputs);
  691. MS_EXCEPTION_IF_NULL(assign_add_cnode);
  692. assign_add_cnode->set_scope(kDefaultScope);
  693. assign_add_cnode->set_abstract(specify_para->abstract());
  694. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
  695. {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat32});
  696. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  697. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
  698. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_cnode.get());
  699. std::vector<std::string> input_names = {"ref", "value"};
  700. std::vector<std::string> output_names = {"output"};
  701. ValuePtr input_names_v = MakeValue(input_names);
  702. ValuePtr output_names_v = MakeValue(output_names);
  703. common::AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_cnode);
  704. common::AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_cnode);
  705. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  706. session::AnfWithOutIndex final_pair = std::make_pair(assign_add_cnode, 0);
  707. session::KernelWithIndex kernel_with_index =
  708. common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(assign_add_cnode, 0), 0);
  709. kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
  710. return assign_add_cnode;
  711. }
  712. CNodePtr KernelAdjust::CreateAssign(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  713. const AnfNodePtr &specify_para) {
  714. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  715. MS_EXCEPTION_IF_NULL(specify_para);
  716. std::vector<float> reset(kNPUShape, 0.0);
  717. ShapeVector reset_shape({static_cast<int64_t>(kNPUShape)});
  718. auto shp_buf_size = sizeof(float) * reset.size();
  719. auto reset_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, reset_shape, reset.data(), shp_buf_size);
  720. auto reset_value_node = std::make_shared<ValueNode>(reset_tensor);
  721. MS_EXCEPTION_IF_NULL(reset_value_node);
  722. reset_value_node->set_abstract(specify_para->abstract());
  723. kernel_graph_ptr->AddValueNodeToGraph(reset_value_node);
  724. auto kernel_info = std::make_shared<device::KernelInfo>();
  725. MS_EXCEPTION_IF_NULL(kernel_info);
  726. reset_value_node->set_kernel_info(kernel_info);
  727. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
  728. builder1.SetOutputsFormat({kOpFormat_DEFAULT});
  729. builder1.SetOutputsDeviceType({kNumberTypeFloat32});
  730. AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), reset_value_node.get());
  731. auto assign_primitive = std::make_shared<Primitive>(kAssignOpName);
  732. std::vector<AnfNodePtr> assign_inputs = {NewValueNode(assign_primitive), specify_para, reset_value_node};
  733. auto assign_cnode = kernel_graph_ptr->NewCNode(assign_inputs);
  734. MS_EXCEPTION_IF_NULL(assign_cnode);
  735. assign_cnode->set_scope(kDefaultScope);
  736. assign_cnode->set_abstract(specify_para->abstract());
  737. kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
  738. {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat32});
  739. selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  740. selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
  741. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_cnode.get());
  742. std::vector<std::string> input_names = {"ref", "value"};
  743. std::vector<std::string> output_names = {"output"};
  744. ValuePtr input_names_v = MakeValue(input_names);
  745. ValuePtr output_names_v = MakeValue(output_names);
  746. common::AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_cnode);
  747. common::AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_cnode);
  748. selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
  749. session::AnfWithOutIndex final_pair = std::make_pair(assign_cnode, 0);
  750. session::KernelWithIndex kernel_with_index =
  751. common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(assign_cnode, 0), 0);
  752. kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
  753. return assign_cnode;
  754. }
  755. void KernelAdjust::InsertOverflowCheckOperations(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  756. MS_LOG(INFO) << "Start Insert Overflow Check Operations.";
  757. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  758. auto parameters = kernel_graph_ptr->parameters();
  759. AnfNodePtr specify_para;
  760. bool not_find = true;
  761. for (size_t i = 0; i < parameters.size(); i++) {
  762. auto para_fullname = parameters[i]->fullname_with_scope();
  763. if (para_fullname.find(kSpecifyParameter) != std::string::npos) {
  764. not_find = false;
  765. specify_para = parameters[i];
  766. break;
  767. }
  768. }
  769. if (not_find) {
  770. MS_LOG(INFO) << "Not find parameter named " << kSpecifyParameter;
  771. return;
  772. }
  773. bool first_grad_op = true;
  774. CNodePtr npu_alloc_cnode;
  775. std::vector<CNodePtr> new_execution_order;
  776. auto execution_order = kernel_graph_ptr->execution_order();
  777. for (size_t i = 0; i < execution_order.size() - 1; i++) {
  778. new_execution_order.push_back(execution_order[i]);
  779. auto cur_full_name = execution_order[i]->fullname_with_scope();
  780. auto next_full_name = execution_order[i + 1]->fullname_with_scope();
  781. auto cur_stream_id = AnfAlgo::GetStreamId(execution_order[i]);
  782. auto next_stream_id = AnfAlgo::GetStreamId(execution_order[i + 1]);
  783. if (cur_full_name.find(kGradients) == std::string::npos && next_full_name.find(kGradients) != std::string::npos) {
  784. if (first_grad_op) {
  785. npu_alloc_cnode = CreateNPUAllocStatus(kernel_graph_ptr);
  786. auto npu_clear_cnode = CreateNPUClearStatus(kernel_graph_ptr, npu_alloc_cnode);
  787. auto assign_cnode = CreateAssign(kernel_graph_ptr, specify_para);
  788. AnfAlgo::SetStreamId(next_stream_id, npu_alloc_cnode.get());
  789. AnfAlgo::SetStreamId(next_stream_id, npu_clear_cnode.get());
  790. AnfAlgo::SetStreamId(next_stream_id, assign_cnode.get());
  791. new_execution_order.push_back(npu_alloc_cnode);
  792. new_execution_order.push_back(npu_clear_cnode);
  793. new_execution_order.push_back(assign_cnode);
  794. first_grad_op = false;
  795. } else {
  796. auto npu_clear_cnode = CreateNPUClearStatus(kernel_graph_ptr, npu_alloc_cnode);
  797. AnfAlgo::SetStreamId(next_stream_id, npu_clear_cnode.get());
  798. new_execution_order.push_back(npu_clear_cnode);
  799. }
  800. }
  801. if (cur_full_name.find(kGradients) != std::string::npos && next_full_name.find(kGradients) == std::string::npos) {
  802. auto npu_get_cnode = CreateNPUGetFloatStatus(kernel_graph_ptr, npu_alloc_cnode);
  803. auto assign_add_cnode = CreateAssignAdd(kernel_graph_ptr, npu_alloc_cnode, specify_para);
  804. AnfAlgo::SetStreamId(cur_stream_id, npu_get_cnode.get());
  805. AnfAlgo::SetStreamId(cur_stream_id, assign_add_cnode.get());
  806. new_execution_order.push_back(npu_get_cnode);
  807. new_execution_order.push_back(assign_add_cnode);
  808. }
  809. if (i == execution_order.size() - kLastHandleDiff) {
  810. new_execution_order.push_back(execution_order[i + 1]);
  811. if (next_full_name.find(kGradients) != std::string::npos) {
  812. auto npu_get_cnode = CreateNPUGetFloatStatus(kernel_graph_ptr, npu_alloc_cnode);
  813. auto assign_add_cnode = CreateAssignAdd(kernel_graph_ptr, npu_alloc_cnode, specify_para);
  814. AnfAlgo::SetStreamId(cur_stream_id, npu_get_cnode.get());
  815. AnfAlgo::SetStreamId(cur_stream_id, assign_add_cnode.get());
  816. new_execution_order.push_back(npu_get_cnode);
  817. new_execution_order.push_back(assign_add_cnode);
  818. }
  819. }
  820. }
  821. kernel_graph_ptr->set_execution_order(new_execution_order);
  822. }
  823. // device loop control
  824. std::shared_ptr<Tensor> KernelAdjust::CreateTensor(int32_t initial_value) {
  825. ShapeVector shp = {1};
  826. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
  827. MS_EXCEPTION_IF_NULL(tensor);
  828. auto val = static_cast<int32_t *>(tensor->data_c());
  829. MS_EXCEPTION_IF_NULL(val);
  830. *val = initial_value;
  831. return tensor;
  832. }
  833. std::shared_ptr<Parameter> KernelAdjust::CreateParameter(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  834. const string parameter_name) {
  835. ShapeVector shp = {1};
  836. tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
  837. MS_EXCEPTION_IF_NULL(tensor_ptr);
  838. mindspore::abstract::AbstractBasePtr parameter_abstract_ptr = tensor_ptr->ToAbstract();
  839. if (parameter_abstract_ptr == nullptr) {
  840. MS_LOG(EXCEPTION) << "Create abstract for device loop control failed!";
  841. }
  842. ParameterPtr param = std::make_shared<Parameter>(kernel_graph_ptr);
  843. MS_EXCEPTION_IF_NULL(param);
  844. param->set_name(parameter_name);
  845. param->set_abstract(parameter_abstract_ptr);
  846. ParameterPtr graph_parameter = kernel_graph_ptr->NewParameter(param);
  847. return graph_parameter;
  848. }
  849. void KernelAdjust::InsertDeviceLoopCtrl(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  850. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  851. std::map<std::string, tensor::TensorPtr> device_loop_ctrl_tensors;
  852. std::map<std::string, mindspore::ParameterPtr> device_loop_ctrl_params;
  853. // current loop count
  854. device_loop_ctrl_tensors[kCurLoopCountName] = CreateTensor(0);
  855. device_loop_ctrl_params[kCurLoopCountName] = CreateParameter(kernel_graph_ptr, kCurLoopCountName);
  856. // next loop count tensor
  857. device_loop_ctrl_tensors[kNextLoopCountName] = CreateTensor(0);
  858. device_loop_ctrl_params[kNextLoopCountName] = CreateParameter(kernel_graph_ptr, kNextLoopCountName);
  859. // current epoch count tensor
  860. device_loop_ctrl_tensors[kCurEpochCountName] = CreateTensor(0);
  861. device_loop_ctrl_params[kCurEpochCountName] = CreateParameter(kernel_graph_ptr, kCurEpochCountName);
  862. // constant one tensor
  863. device_loop_ctrl_tensors[kConstOneName] = CreateTensor(1);
  864. device_loop_ctrl_params[kConstOneName] = CreateParameter(kernel_graph_ptr, kConstOneName);
  865. // constant loop num in epoch tensor
  866. int32_t initial_value = 0;
  867. if (NeedLoopSink()) {
  868. // iter_num minus one because the device side counts from 0
  869. initial_value = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num() - 1));
  870. } else {
  871. MS_LOG(INFO) << "Tensor const_loop_num_in_epoch only used in loop sink mode.";
  872. initial_value = 0;
  873. }
  874. MS_LOG(INFO) << "Loop num in epoch is " << initial_value;
  875. device_loop_ctrl_tensors[kConstLoopNumInEpochName] = CreateTensor(initial_value);
  876. device_loop_ctrl_params[kConstLoopNumInEpochName] = CreateParameter(kernel_graph_ptr, kConstLoopNumInEpochName);
  877. kernel_graph_ptr->set_device_loop_ctrl_tensors(device_loop_ctrl_tensors);
  878. kernel_graph_ptr->set_device_loop_ctrl_params(device_loop_ctrl_params);
  879. }
  880. void KernelAdjust::AssignLoopCtrlTensorMem(const session::KernelGraph &kernel_graph, KernelRuntime *runtime_instance,
  881. const string name) {
  882. MS_EXCEPTION_IF_NULL(runtime_instance);
  883. auto device_loop_control_params = kernel_graph.device_loop_control_params();
  884. if (!device_loop_control_params.count(name)) {
  885. MS_LOG(WARNING) << "Can't find Device Loop Control Parameter " << name;
  886. return;
  887. }
  888. auto param = device_loop_control_params.at(name);
  889. MS_EXCEPTION_IF_NULL(param);
  890. DeviceAddressPtr device_address = nullptr;
  891. if (AnfAlgo::OutputAddrExist(param, 0)) {
  892. device_address = AnfAlgo::GetMutableOutputAddr(param, 0);
  893. MS_EXCEPTION_IF_NULL(device_address);
  894. } else {
  895. MS_LOG(INFO) << "Device Loop Control Parameter " << name << " have no address, allocating...";
  896. auto size = AnfAlgo::GetOutputTensorMemSize(param, 0);
  897. auto format = AnfAlgo::GetOutputFormat(param, 0);
  898. auto type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
  899. auto ms_context = MsContext::GetInstance();
  900. MS_EXCEPTION_IF_NULL(ms_context);
  901. auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  902. device_address =
  903. std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
  904. device_address->set_is_ptr_persisted(true);
  905. if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
  906. MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name
  907. << " , tensor size is : " << size;
  908. }
  909. MS_EXCEPTION_IF_NULL(device_address);
  910. AnfAlgo::SetOutputAddr(device_address, 0, param.get());
  911. }
  912. auto device_loop_control_tensors = kernel_graph.device_loop_control_tensors();
  913. auto tensor = device_loop_control_tensors.at(name);
  914. MS_EXCEPTION_IF_NULL(tensor);
  915. tensor->set_device_address(device_address);
  916. if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(param, 0), LongToSize(tensor->data().nbytes()),
  917. tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_)) {
  918. MS_LOG(EXCEPTION) << "SyncHostToDevice failed for device loop control parameter " << name;
  919. }
  920. }
  921. void KernelAdjust::AssignLoopCtrlMemory(const session::KernelGraph &kernel_graph_ptr) {
  922. auto device_loop_control_tensors = kernel_graph_ptr.device_loop_control_tensors();
  923. if (device_loop_control_tensors.empty()) {
  924. return;
  925. }
  926. MS_LOG(INFO) << "Assign device loop control memory";
  927. auto ms_context = MsContext::GetInstance();
  928. MS_EXCEPTION_IF_NULL(ms_context);
  929. auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  930. auto runtime_instance = KernelRuntimeManager::Instance().GetSingleKernelRuntime(kAscendDevice, device_id);
  931. MS_EXCEPTION_IF_NULL(runtime_instance);
  932. AssignLoopCtrlTensorMem(kernel_graph_ptr, runtime_instance, kCurLoopCountName);
  933. AssignLoopCtrlTensorMem(kernel_graph_ptr, runtime_instance, kNextLoopCountName);
  934. AssignLoopCtrlTensorMem(kernel_graph_ptr, runtime_instance, kCurEpochCountName);
  935. AssignLoopCtrlTensorMem(kernel_graph_ptr, runtime_instance, kConstOneName);
  936. AssignLoopCtrlTensorMem(kernel_graph_ptr, runtime_instance, kConstLoopNumInEpochName);
  937. }
  938. void KernelAdjust::SetDeviceLoopCtrlTensor(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
  939. const std::string name, int32_t value) {
  940. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  941. auto device_loop_control_tensors = kernel_graph_ptr->device_loop_control_tensors();
  942. if (!device_loop_control_tensors.count(name)) {
  943. MS_LOG(WARNING) << "Can't find Device Loop Control Tensor " << name;
  944. return;
  945. }
  946. auto tensor = device_loop_control_tensors.at(name);
  947. MS_EXCEPTION_IF_NULL(tensor);
  948. auto *cur_val = static_cast<int32_t *>(tensor->data_c());
  949. MS_EXCEPTION_IF_NULL(cur_val);
  950. *cur_val = value;
  951. tensor->set_sync_status(kNeedSyncHostToDevice);
  952. auto device_address = tensor->device_address();
  953. MS_EXCEPTION_IF_NULL(device_address);
  954. if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
  955. tensor->data_c(), tensor->device_info().host_format_)) {
  956. MS_LOG(EXCEPTION) << "SyncHostToDevice failed for device loop control parameter " << name;
  957. }
  958. }
  959. void KernelAdjust::LoadDeviceLoopCtrlParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  960. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  961. auto device_loop_control_tensors = kernel_graph_ptr->device_loop_control_tensors();
  962. if (device_loop_control_tensors.empty()) {
  963. return;
  964. }
  965. MS_LOG(INFO) << "Load device loop control data";
  966. SetDeviceLoopCtrlTensor(kernel_graph_ptr, kCurLoopCountName, 0);
  967. SetDeviceLoopCtrlTensor(kernel_graph_ptr, kNextLoopCountName, 0);
  968. #ifndef ENABLE_SECURITY
  969. SetDeviceLoopCtrlTensor(kernel_graph_ptr, kCurEpochCountName,
  970. SizeToInt(DumpJsonParser::GetInstance().cur_dump_iter()));
  971. #else
  972. SetDeviceLoopCtrlTensor(kernel_graph_ptr, kCurEpochCountName, 0);
  973. #endif
  974. kernel_graph_ptr->set_current_epoch(kernel_graph_ptr->current_epoch() + 1);
  975. }
  976. } // namespace device
  977. } // namespace mindspore