You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 108 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/graph_scheduler/graph_scheduler.h"
  17. #include <queue>
  18. #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
  19. #include "runtime/graph_scheduler/actor/debug_actor.h"
  20. #include "runtime/graph_scheduler/actor/recorder_actor.h"
  21. #include "runtime/hardware/device_context_manager.h"
  22. #include "mindrt/src/actor/actormgr.h"
  23. #include "mindrt/include/async/async.h"
  24. #include "backend/common/session/anf_runtime_algorithm.h"
  25. #include "include/common/utils/anfalgo.h"
  26. #include "backend/common/optimizer/helper.h"
  27. #include "utils/anf_utils.h"
  28. #include "include/common/utils/config_manager.h"
  29. #include "utils/log_adapter.h"
  30. #include "include/common/utils/convert_utils.h"
  31. #include "utils/ms_context.h"
  32. #include "utils/profile.h"
  33. #if !defined(_WIN32) && !defined(_WIN64)
  34. #include "include/common/utils/signal_util.h"
  35. #endif
  36. #ifndef ENABLE_SECURITY
  37. #include "debug/data_dump/dump_json_parser.h"
  38. #endif
  39. #ifdef ENABLE_DUMP_IR
  40. #include "include/common/debug/rdr/recorder_manager.h"
  41. #endif
  42. #ifdef ENABLE_DEBUGGER
  43. #include "debug/debugger/debugger.h"
  44. #endif
  45. #include "profiler/device/profiling.h"
  46. #include "include/common/debug/common.h"
  47. #include "runtime/recovery/recovery_context.h"
  48. namespace mindspore {
  49. namespace runtime {
  50. using recovery::RecoveryContext;
  51. namespace {
  52. bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
  53. MS_EXCEPTION_IF_NULL(from_device_context);
  54. MS_EXCEPTION_IF_NULL(to_device_context);
  55. if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
  56. return false;
  57. } else {
  58. return true;
  59. }
  60. }
  61. inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
  62. MS_EXCEPTION_IF_NULL(actor_set);
  63. return actor_set->kernel_actors_.size() == 1;
  64. }
  65. bool IsTakenOverByControlFlow(const AnfNodePtr &front_node, const KernelGraphPtr &graph,
  66. const ControlNodeParserPtr &parser) {
  67. MS_EXCEPTION_IF_NULL(front_node);
  68. MS_EXCEPTION_IF_NULL(graph);
  69. if (common::AnfAlgo::IsCallNode(front_node)) {
  70. return true;
  71. }
  72. if (parser != nullptr && parser->IsInited() && (!parser->IsSameKernelGraphGroup(front_node, graph))) {
  73. return true;
  74. }
  75. return false;
  76. }
  77. // Convert the actors vector by the actor set.
  78. std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
  79. MS_EXCEPTION_IF_NULL(actor_set);
  80. std::vector<AbstractActorPtr> actors;
  81. if (actor_set->data_prepare_actor_ != nullptr) {
  82. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
  83. }
  84. for (auto &data_source_actor : actor_set->data_source_actors_) {
  85. MS_EXCEPTION_IF_NULL(data_source_actor);
  86. (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
  87. }
  88. for (auto &custom_actor : actor_set->custom_actors_) {
  89. MS_EXCEPTION_IF_NULL(custom_actor);
  90. (void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
  91. }
  92. for (auto &kernel_actor : actor_set->kernel_actors_) {
  93. MS_EXCEPTION_IF_NULL(kernel_actor);
  94. (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
  95. }
  96. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  97. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  98. (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
  99. }
  100. for (auto &copy_actor : actor_set->copy_actors_) {
  101. MS_EXCEPTION_IF_NULL(copy_actor);
  102. (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
  103. }
  104. if (actor_set->loop_count_actor_ != nullptr) {
  105. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
  106. }
  107. if (actor_set->output_actor_ != nullptr) {
  108. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
  109. }
  110. if (actor_set->control_actors_ != nullptr) {
  111. const auto &control_actor_set = actor_set->control_actors_;
  112. for (auto &switch_actor : control_actor_set->switch_actors_) {
  113. MS_EXCEPTION_IF_NULL(switch_actor);
  114. (void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
  115. }
  116. for (auto &gather_actor : control_actor_set->gather_actors_) {
  117. MS_EXCEPTION_IF_NULL(gather_actor);
  118. (void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
  119. }
  120. for (auto &entrance_actor : control_actor_set->entrance_actors_) {
  121. MS_EXCEPTION_IF_NULL(entrance_actor);
  122. (void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
  123. }
  124. for (auto &exit_actor : control_actor_set->exit_actors_) {
  125. MS_EXCEPTION_IF_NULL(exit_actor);
  126. (void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
  127. }
  128. for (auto &stack_actor : control_actor_set->stack_actors_) {
  129. MS_EXCEPTION_IF_NULL(stack_actor);
  130. (void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
  131. }
  132. }
  133. return actors;
  134. }
// Releases the device memory bound to one kernel graph's nodes: input
// parameters, graph value nodes and kernel (cnode) outputs, and removes the
// matching device tensor store entries keyed by the corresponding front nodes.
void ClearNodeInfo(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Clear input parameter device tensor and device tensor store.
  for (const auto &input_node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input_node);
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    auto parameter = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    // A parameter may be shared by several graphs; drop this graph's reference.
    parameter->DecreaseUsedGraphCount();
    // Only the parameter has no graph used, then clear the device tensor.
    if (parameter->used_graph_count() != 0) {
      continue;
    }
    // Device tensor store entries are keyed by front nodes, so translate the
    // backend node first.
    auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
    DeviceTensorStore::GetInstance().Remove(front_input_node.get());
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(input_node);
    for (size_t index = 0; index < output_num; ++index) {
      if (AnfAlgo::OutputAddrExist(input_node, index)) {
        // Setting a null output address releases the device memory reference.
        AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
      }
    }
  }
  // Clear input value node device tensor and device tensor store.
  for (const auto &value_node : graph->graph_value_nodes()) {
    auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
    DeviceTensorStore::GetInstance().Remove(front_value_node.get());
    // Value nodes have a single output at index 0.
    if (AnfAlgo::OutputAddrExist(value_node, 0)) {
      AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
    }
  }
  // Clear cnode device tensor.
  for (const auto &cnode : graph->execution_order()) {
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      if (AnfAlgo::OutputAddrExist(cnode, index)) {
        AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
      }
    }
  }
}
  177. #if !defined(_WIN32) && !defined(_WIN64)
  178. void IntHandler(int, siginfo_t *, void *) {
  179. int this_pid = getpid();
  180. MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
  181. (void)kill(this_pid, SIGTERM);
  182. }
  183. #endif
  184. } // namespace
  185. GraphScheduler &GraphScheduler::GetInstance() noexcept {
  186. static GraphScheduler instance{};
  187. return instance;
  188. }
  189. void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs,
  190. const std::vector<AnfNodePtr> &root_graph_parameters,
  191. const ControlNodeParserPtr &parser) noexcept {
  192. // Terminate the actors of actor info.
  193. if (actors_.count(actor_info) > 0) {
  194. auto actor_manager = ActorMgr::GetActorMgrRef();
  195. if (actor_manager == nullptr) {
  196. MS_LOG(ERROR) << "Actor manager is not exist.";
  197. return;
  198. }
  199. auto actor_set = actors_[actor_info];
  200. auto base_actors = CollectActors(actor_set.get());
  201. for (auto &base_actor : base_actors) {
  202. MS_EXCEPTION_IF_NULL(base_actor);
  203. EraseActor(base_actor->GetAID().Name());
  204. actor_manager->Terminate(base_actor->GetAID());
  205. }
  206. }
  207. // Clear device tensor and device tensor store.
  208. for (auto &graph : graphs) {
  209. ClearNodeInfo(graph);
  210. }
  211. if (parser != nullptr && parser->IsInited()) {
  212. const auto &front_value_nodes = parser->front_value_nodes();
  213. for (const auto &front_value_node : front_value_nodes) {
  214. const auto &node = front_value_node.first.first;
  215. size_t index = front_value_node.first.second;
  216. if (AnfAlgo::OutputAddrExist(node, index)) {
  217. AnfAlgo::SetOutputAddr(nullptr, index, node.get());
  218. }
  219. }
  220. }
  221. // Clear the member of DeviceTensorStore.
  222. for (auto &root_graph_parameter : root_graph_parameters) {
  223. DeviceTensorStore::GetInstance().Remove(root_graph_parameter.get());
  224. }
  225. // Clear global maps of actor info.
  226. (void)actors_.erase(actor_info);
  227. }
// Clears the whole scheduler state: finalizes every actor in the runtime
// first (so nothing is still running), then empties the device tensor store
// and drops all cached actor sets and the global actor registry.
void GraphScheduler::Clear() {
  // Terminate all actors.
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  actor_manager->Finalize();
  // Clear the member of DeviceTensorStore.
  DeviceTensorStore::GetInstance().Clear();
  // Clear global maps.
  actors_.clear();
  ClearAllActors();
}
  239. void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  240. MS_EXCEPTION_IF_NULL(actor_set);
  241. // Clear the member of DeviceTensorCopyStore.
  242. DeviceTensorCopyStore::GetInstance().Clear();
  243. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  244. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  245. super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
  246. }
  247. control_node_scheduler_.ClearActorData(actor_set->control_actors_.get());
  248. // At the end of the step, the op data sent to the stack actor in each actor should be clear.
  249. auto total_actors = CollectActors(actor_set);
  250. for (auto &actor : total_actors) {
  251. MS_EXCEPTION_IF_NULL(actor);
  252. actor->to_stack_data_.clear();
  253. }
  254. }
// Member-function-pointer type used to dispatch data-arrow linking by the
// kind of the arrow's source actor.
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
                                                   const KernelWithIndex &, const KernelGraphPtr &);
// Maps each kernel transform type to its link function; populated once in
// GraphScheduler::Initialize().
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
// One-time initialization of the scheduler: registers the per-actor-kind data
// arrow link functions, creates the actor runtime thread pool, and spawns the
// global actors. Subsequent calls return immediately.
// NOTE(review): init_ is set before the work completes and there is no lock
// here — presumably Initialize() is only called from a single thread; confirm.
void GraphScheduler::Initialize() {
  if (init_) {
    return;
  }
  init_ = true;
  // Register the data-arrow link function for each source-actor kind; these
  // are dispatched through kKernelTypeToLinkFunc during graph linking.
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForHostDSActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSuperKernelActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
                                      &GraphScheduler::LinkDataArrowForDeviceTensorStore);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
                                      &GraphScheduler::LinkDataArrowForInternalParameter);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSendActor, &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kRecvActor, &GraphScheduler::LinkDataArrowForBaseActor);
  // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
  size_t actor_thread_num = 0;
  size_t actor_and_kernel_thread_num = 0;
  ComputeThreadNums(&actor_thread_num, &actor_and_kernel_thread_num);
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  auto ret = actor_manager->Initialize(true, actor_thread_num, actor_and_kernel_thread_num);
  if (ret != MINDRT_OK) {
    MS_LOG(EXCEPTION) << "Actor manager init failed.";
  }
  common::SetOMPThreadNum();
  MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
               << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
#ifdef ENABLE_RPC_ACTOR
  // Create and initialize RpcNodeScheduler.
  rpc_node_scheduler_ = std::make_unique<RpcNodeScheduler>();
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  rpc_node_scheduler_->Initialize();
#endif
  BuildAndScheduleGlobalActor();
}
  297. void GraphScheduler::BuildAndScheduleGlobalActor() {
  298. auto actor_manager = ActorMgr::GetActorMgrRef();
  299. MS_EXCEPTION_IF_NULL(actor_manager);
  300. // Create and schedule memory manager actor.
  301. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  302. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  303. memory_manager_aid_ = memory_manager_actor->GetAID();
  304. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  305. // Bind single thread to response to memory alloc and free quickly.
  306. (void)actor_manager->Spawn(base_actor, false);
  307. // Create and schedule recorder actor.
  308. auto recorder_actor = std::make_shared<RecorderActor>();
  309. MS_EXCEPTION_IF_NULL(recorder_actor);
  310. recorder_aid_ = &(recorder_actor->GetAID());
  311. auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
  312. (void)actor_manager->Spawn(base_recorder_actor, true);
  313. // Create and schedule debug actor.
  314. // debugger_actor_need is true for CPU when e2e dump is enabled and for Ascend and GPU is true when debugger or dump
  315. // is enabled.
  316. #ifndef ENABLE_SECURITY
  317. bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
  318. #endif
  319. #ifdef ENABLE_DEBUGGER
  320. if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
  321. debugger_actor_need = true;
  322. }
  323. #endif
  324. #ifndef ENABLE_SECURITY
  325. if (debugger_actor_need) {
  326. auto debug_actor = std::make_shared<DebugActor>();
  327. MS_EXCEPTION_IF_NULL(debug_actor);
  328. debug_aid_ = &(debug_actor->GetAID());
  329. auto base_debug_actor = static_cast<ActorReference>(debug_actor);
  330. (void)actor_manager->Spawn(base_debug_actor, true);
  331. }
  332. #endif
  333. }
// Transforms the compiled graphs into a linked actor set: persists weights,
// builds the actors, caches graph outputs, links the actor DAG, optimizes it
// and (for pipeline mode) validates it. The returned raw pointer remains
// valid because Build() registers the ActorSet in the owning actors_ map.
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
  // RAII guard that clears the transient per-transform maps even if an
  // exception escapes any step below.
  struct ScopeCleaner {
    GraphScheduler *const scheduler_;
    explicit ScopeCleaner(GraphScheduler *scheduler) : scheduler_(scheduler) {}
    ~ScopeCleaner() {
      // Local maps and vectors clear.
      if (scheduler_ == nullptr) {
        return;
      }
      scheduler_->graph_output_to_actor_.clear();
      scheduler_->copy_actors_.clear();
    }
  };
  // cppcheck-suppress unreadVariable
  ScopeCleaner cleaner(this);
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
  if (graph_compiler_info.graphs_.size() == 0) {
    MS_LOG(EXCEPTION) << "The number of graphs is zero.";
  }
  // Each graph must be paired with a device context.
  if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }
  PersistDeviceTensor(graph_compiler_info);
  const auto &actor_set = Build(graph_compiler_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  // Cache outputs before Link(): linking consults graph_output_to_actor_.
  CacheGraphOutputToActor(graph_compiler_info);
  Link(actor_set.get(), graph_compiler_info);
  Optimize(actor_set.get());
  DumpActor(actor_set.get(), graph_compiler_info);
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
    CheckActorValid(actor_set.get());
  }
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
  return actor_set.get();
}
  369. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  370. MS_EXCEPTION_IF_NULL(actor_set);
  371. auto actors = CollectActors(actor_set);
  372. // Schedule actors.
  373. auto actor_manager = ActorMgr::GetActorMgrRef();
  374. MS_EXCEPTION_IF_NULL(actor_manager);
  375. for (auto actor : actors) {
  376. (void)actor_manager->Spawn(actor);
  377. }
  378. #ifdef ENABLE_RPC_ACTOR
  379. // Build physical connections in 'RpcNodeScheduler::Schedule()' method. This costs some time.
  380. MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  381. rpc_node_scheduler_->Schedule();
  382. #endif
  383. }
// Runs one step of the actor set. A step-mode single-op set is executed
// synchronously on this thread; otherwise the pipeline is kicked off
// asynchronously and this call blocks on the result promise until the run
// finishes, handling failure, exceptions and collective-comm recovery.
// NOTE(review): device_contexts is not referenced in this body — presumably
// kept for interface compatibility; confirm with callers.
void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
                         const std::vector<std::vector<TensorPtr>> &input_tensors,
                         const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
#if !defined(_WIN32) && !defined(_WIN64)
  // Route Ctrl-C through IntHandler for the duration of the run.
  SignalGuard sg(IntHandler);
#endif
  // Construct OpContext.
  OpContext<DeviceTensor> op_context;
  std::vector<Promise<int>> result(1);
  op_context.sequential_num_ = RandInt::Instance().Get();
  op_context.results_ = &result;
#ifdef ENABLE_RPC_ACTOR
  // Set OpContext to rpc node scheduler.
  auto op_context_setter = std::make_shared<RpcActorOpContextSetter>(rpc_node_scheduler_.get(), &op_context);
  MS_EXCEPTION_IF_NULL(op_context_setter);
#endif
  // Fast path: a step-mode set with exactly one kernel actor runs
  // synchronously, bypassing the message pipeline entirely.
  if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
    actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context, GraphExecutionStrategy::kStep);
    MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
    actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
    return;
  }
  // Trigger data prepare actor running.
  MS_EXCEPTION_IF_NULL(ActorMgr::GetActorMgrRef());
  auto thread_pool = ActorMgr::GetActorMgrRef()->GetActorThreadPool();
  MS_EXCEPTION_IF_NULL(thread_pool);
  ActorDispatcher::is_multi_thread_execution(actor_set->is_multi_thread_execution_);
  double start_time = GetTime();
  ActorDispatcher::Send(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors,
                        &op_context, GraphExecutionStrategy::kPipeline);
  // Get the run result.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  MsException::Instance().CheckException();
  if (!result_future.IsOK()) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    // When temporary variable 'op_context' has beed set failed status, the main thread need wait other threads until
    // they finish respective task, otherwise segmentation fault will happen when these task access 'op_context',
    // because it has been destroyed.
    std::mutex mutex;
    std::unique_lock<std::mutex> locker(mutex);
    std::condition_variable thread_blocker;
    const int64_t kTimeToWait = 2;
    // Nothing notifies this condition variable; wait_for is used as a bounded
    // sleep that keeps op_context alive while worker threads drain.
    (void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait));
    // May set exception in the wait time, need throw the exception to avoid affecting the next execution.
    MsException::Instance().CheckException();
    MS_LOG(EXCEPTION) << op_context.error_info_;
  }
  double end_time = GetTime();
  const size_t kSecondsToMilliseconds = 1000;
  // Feed the measured step duration into the single/multi-thread strategy tuning.
  SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
  // Disaster-recovery: re-initialize collective communication when flagged.
  if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
    MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
    if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
      MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
    }
    MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
    RecoveryContext::GetInstance()->set_need_reinit_collective(false);
  }
}
  448. void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  449. double execution_time) const {
  450. MS_EXCEPTION_IF_NULL(actor_set);
  451. MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
  452. ++actor_set->execution_count_;
  453. MS_LOG(DEBUG) << "Execution count: " << actor_set->execution_count_ << ", execution time cost: " << execution_time
  454. << " ms in multi thread or not: " << actor_set->is_multi_thread_execution_ << ".";
  455. #if defined(_WIN32) || defined(_WIN64)
  456. return;
  457. #endif
  458. // The step mode uses the default multi thread.
  459. if (strategy == GraphExecutionStrategy::kStep) {
  460. return;
  461. }
  462. // The constraint condition of not supporting the single thread execution.
  463. if ((actor_set->control_actors_ != nullptr) || (actor_set->copy_actors_.size() > 0) ||
  464. (actor_set->super_kernel_actors_.size() > 0) || (actor_set->loop_count_actor_->loop_count() > 1) ||
  465. (actor_set->kernel_actors_.size() > ActorDispatcher::kSingleThreadExecutionActorMaxNum)) {
  466. return;
  467. }
  468. if ((actor_set->is_multi_thread_execution_) &&
  469. (actor_set->execution_count_ >= ActorDispatcher::kMultiThreadExecutionCountBegin) &&
  470. (actor_set->execution_count_ <= ActorDispatcher::kMultiThreadExecutionCountEnd)) {
  471. actor_set->multi_thread_execution_time_ += execution_time;
  472. if (actor_set->execution_count_ == ActorDispatcher::kMultiThreadExecutionCountEnd) {
  473. actor_set->multi_thread_execution_time_ /=
  474. ((ActorDispatcher::kMultiThreadExecutionCountEnd - ActorDispatcher::kMultiThreadExecutionCountBegin) + 1);
  475. actor_set->is_multi_thread_execution_ = false;
  476. }
  477. return;
  478. }
  479. if ((!actor_set->is_multi_thread_execution_) &&
  480. (actor_set->execution_count_ >= ActorDispatcher::kSingleThreadExecutionCountBegin) &&
  481. (actor_set->execution_count_ <= ActorDispatcher::kSingleThreadExecutionCountEnd)) {
  482. actor_set->single_thread_execution_time_ += execution_time;
  483. if (actor_set->execution_count_ == ActorDispatcher::kSingleThreadExecutionCountEnd) {
  484. actor_set->single_thread_execution_time_ /=
  485. (ActorDispatcher::kSingleThreadExecutionCountEnd - ActorDispatcher::kSingleThreadExecutionCountBegin + 1);
  486. actor_set->is_multi_thread_execution_ =
  487. (actor_set->multi_thread_execution_time_ <= actor_set->single_thread_execution_time_) ? true : false;
  488. MS_LOG(INFO) << "Multi thread execution time cost: " << actor_set->multi_thread_execution_time_
  489. << " ms, single thread execution time cost: " << actor_set->single_thread_execution_time_
  490. << " ms, decide to use multi thread execution or not: " << actor_set->is_multi_thread_execution_
  491. << ".";
  492. }
  493. return;
  494. }
  495. }
  496. ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
  497. auto iter = actors_.find(actor_info);
  498. if (iter != actors_.end()) {
  499. return iter->second.get();
  500. } else {
  501. MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
  502. return nullptr;
  503. }
  504. }
// Build the whole actor set for one compiled graph (data source, custom,
// kernel, super kernel, loop count, output, data prepare, control and rpc
// actors) and register the set into the scheduler's `actors_` map.
// Note: the build order matters — the data prepare actor consumes the already
// built data source actors, and the control/rpc schedulers run last.
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
  auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
  MS_EXCEPTION_IF_NULL(actor_set);
  (void)actors_.emplace(actor_set->name_, actor_set);

  // The host tensor queue is shared by the host queue data source actor and the data prepare actor.
  auto host_queue = std::make_shared<HostTensorQueue>();
  actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
  actor_set->custom_actors_ = BuildCustomActor(graph_compiler_info);
  actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  // Needs the data source actors to locate the host queue data source actor.
  actor_set->data_prepare_actor_ =
    BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
  actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
#ifdef ENABLE_RPC_ACTOR
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  actor_set->rpc_actors_ = rpc_node_scheduler_->Build(graph_compiler_info);
#endif
  return actor_set;
}
// Cache, for every graph output, the mapping from its front node (with index)
// to the producing actor and backend output (with index) so that later arrow
// linking can resolve front outputs to actors. Skipped in step mode.
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
    return;
  }
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output_with_index : outputs) {
      auto output_kernel = output_with_index.first;
      MS_EXCEPTION_IF_NULL(output_kernel);
      // Backend outputs without a corresponding front node are skipped with a warning.
      auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
      if (origin_output_with_index.first == nullptr) {
        MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                        << " with index: " << output_with_index.second << " has no front node.";
        continue;
      }
      auto kernel_type = FetchKernelTransformType(output_kernel, graph, graph_compiler_info.origin_parameters_order_);
      auto output_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_kernel, graph);
      // A missing actor is tolerated: the pair is still cached below with a null actor and an empty name.
      if (output_actor == nullptr) {
        MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                     << " with index:" << output_with_index.second
                     << " is not actor, and the kernel type is:" << kernel_type;
      }
      auto output_actor_name = (output_actor != nullptr) ? output_actor->GetAID().Name() : "";
      (void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index));
      MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                   << " with index:" << output_with_index.second << " to actor:" << output_actor_name
                   << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
                   << " with index:" << origin_output_with_index.second;
    }
  }
}
// Link all arrows between the built actors of the actor set: per-graph data
// arrows (sink or non-sink mode), global control arrows, output result arrows,
// then control-flow arrows and rpc inter-process arrows.
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<AbstractActor *> auto_monad_actors;
  GroupNameToCommuNodes group_name_to_communication_nodes;
  std::string default_group_name = "";
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  // Link the data arrows of each non-empty graph, collecting the auto monad
  // actors and the communication nodes along the way.
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->execution_order().empty()) {
      MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips linking.";
      continue;
    }
    if (graph->is_executing_sink()) {
      LinkDataArrowInSinkMode(graph, graph_compiler_info, &auto_monad_actors);
    } else {
      // In the control flow, the communication nodes need to be guaranteed to be executed in order. The order
      // within the kernel graph group needs to add control arrows between the communication nodes, and the order
      // between groups is guaranteed by the control flow framework. Therefore, communication nodes need to be
      // grouped by group name. And this is not required in non-control flow, the default unified group name is used.
      std::vector<CNodePtr> communication_nodes;
      const auto &group_name = (parser->IsInited() ? parser->FetchGroupNameByKernelGraph(graph) : default_group_name);
      LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes);
      group_name_to_communication_nodes[group_name].insert(group_name_to_communication_nodes[group_name].end(),
                                                           communication_nodes.begin(), communication_nodes.end());
    }
  }
  LinkGlobalControlArrow(actor_set, group_name_to_communication_nodes, auto_monad_actors, graph_compiler_info);
  LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
  // The copy actors are built in the link, so need push into the actor set after link.
  actor_set->copy_actors_ = copy_actors_;
  // Link the arrow in the control flow scene.
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline &&
      graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
    control_node_scheduler_.Link(actor_set, graph_compiler_info);
  }
#ifdef ENABLE_RPC_ACTOR
  // Link inter-process arrows for rpc actors.
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  rpc_node_scheduler_->Link(actor_set);
#endif
}
// Run post-build optimization on the actor set; currently only the control
// node scheduler optimizes its control actors.
void GraphScheduler::Optimize(ActorSet *const actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
  control_node_scheduler_.Optimize(actor_set->control_actors_.get());
}
// Build the data source actors for the actor set:
// - at most one host queue data source actor (created lazily) that gathers the
//   host-fed input parameters of all graphs;
// - one device queue data source actor per non-sink graph that contains a
//   device queue kernel.
// Finally lets the control node scheduler append its own data source actors.
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
                                                                    const HostTensorQueuePtr &host_queue) {
  std::vector<DataSourceActorPtr> data_source_actors;
  HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  // Next free slot in the host queue data source actor's data node list.
  size_t data_node_position = 0;
  // Maps a front node to the slot of the first backend node registered for it.
  mindspore::HashMap<AnfNodePtr, size_t> front_node_position_temp_map;
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Build host queue data source actor.
    const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
    const auto &root_parameters = graph_compiler_info.origin_parameters_order_;
    for (size_t j = 0; j < input_nodes.size(); j++) {
      const auto &input_node = input_nodes[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsHostQueueDSActor(input_node, graph, root_parameters, graph_compiler_info.strategy_)) {
        // In control flow, parameters from subgraph need not init in data source actor.
        if (graph_compiler_info.control_node_parser_->IsInited()) {
          auto node_with_index = graph->GetElementInTupleBackendFrontIndexMap(input_node);
          if (node_with_index.first != nullptr && node_with_index.first->isa<Parameter>() &&
              find(root_parameters.begin(), root_parameters.end(), node_with_index.first) == root_parameters.end())
            continue;
        }
        // Create the single host queue data source actor on the first host-fed input.
        if (host_queue_ds_actor == nullptr) {
          auto actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
          MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
          host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
                                                                          nullptr, host_queue);
          InsertActor(host_queue_ds_actor.get());
          (void)data_source_actors.emplace_back(host_queue_ds_actor);
        }
        const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
        // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
        // is saved in the host queue data source actor.
        if (front_node_position_temp_map.count(front_node) > 0) {
          (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
                                                                     front_node_position_temp_map[front_node]);
          continue;
        }
        (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
        (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
        (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
        // In control flow, need to rely on the front node to find the location of the corresponding real parameter.
        (void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
        (void)front_node_position_temp_map.emplace(front_node, data_node_position);
        data_node_position++;
      }
    }
    // The graph sink mode has no device queue data source actor.
    if (!graph->is_executing_sink()) {
      // Build device queue data source actor.
      const auto &execution_order = graph->execution_order();
      const auto &iter =
        std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
          return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
        });
      if (iter != execution_order.end()) {
        auto actor_name =
          graph_compiler_info.name_ + kDeviceDSActorNameSuffix + "_" + std::to_string(graph->graph_id());
        MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
        auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
          actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
        MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
        InsertActor(device_queue_ds_actor.get());
        (void)data_source_actors.emplace_back(device_queue_ds_actor);
        device_queue_ds_actor->data_kernel_ = *iter;
        device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
      }
    }
  }
  control_node_scheduler_.BuildDataSourceActorForControlNode(graph_compiler_info, host_queue, host_queue_ds_actor,
                                                             memory_manager_aid_, &data_source_actors);
  return data_source_actors;
}
  678. std::vector<CustomActorPtr> GraphScheduler::BuildCustomActor(const GraphCompilerInfo &graph_compiler_info) {
  679. std::vector<CustomActorPtr> custom_actors;
  680. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  681. const auto &device_context = graph_compiler_info.device_contexts_[i];
  682. const auto &graph = graph_compiler_info.graphs_[i];
  683. MS_EXCEPTION_IF_NULL(graph);
  684. if (graph->is_executing_sink()) {
  685. continue;
  686. }
  687. auto all_nodes = TopoSort(graph->get_return());
  688. for (const auto &node : all_nodes) {
  689. if (!AnfUtils::IsCustomActorNode(node)) {
  690. continue;
  691. }
  692. auto actor_name = AnfUtils::GetCustomActorName(node);
  693. auto custom_actor = std::make_shared<CustomActor>(actor_name, node, device_context, recorder_aid_);
  694. MS_EXCEPTION_IF_NULL(custom_actor);
  695. InsertActor(custom_actor.get());
  696. custom_actors.emplace_back(custom_actor);
  697. }
  698. }
  699. return custom_actors;
  700. }
// Build one kernel actor for every non-skipped kernel of every non-sink graph.
// Rpc kernels get dedicated send/recv actors via GenerateRpcActor.
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<KernelActorPtr> kernel_actors;
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->is_executing_sink()) {
      continue;
    }
    auto execution_order = graph->execution_order();
    // Single op graph in step mode, kernel actor executes synchronously.
    bool is_single_op_graph = execution_order.size() == 1;
    // Note: the overridden local `strategy` is passed to the actors, while the
    // IsKernelActor check below deliberately uses the original strategy.
    GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
    if (strategy == GraphExecutionStrategy::kStep) {
      strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
    }
    for (auto &kernel : execution_order) {
      MS_EXCEPTION_IF_NULL(kernel);
      if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
        // Modifiable ref input/output indexes are forwarded to the actor (see FetchModifiableRef*Index).
        auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
        auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
        KernelActorPtr kernel_actor = nullptr;
        if (IsRpcActor(kernel)) {
          kernel_actor = GenerateRpcActor(kernel, device_context, strategy, ref_input_indexes, ref_output_indexes);
        } else {
          kernel_actor =
            std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
                                          debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
        }
        MS_EXCEPTION_IF_NULL(kernel_actor);
        InsertActor(kernel_actor.get());
        (void)kernel_actors.emplace_back(kernel_actor);
      }
    }
  }
  return kernel_actors;
}
  738. std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  739. std::vector<SuperKernelActorPtr> super_kernel_actors;
  740. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  741. const auto &graph = graph_compiler_info.graphs_[i];
  742. const auto &device_context = graph_compiler_info.device_contexts_[i];
  743. MS_EXCEPTION_IF_NULL(graph);
  744. if (!graph->is_executing_sink()) {
  745. continue;
  746. }
  747. if (graph->execution_order().empty()) {
  748. MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips building.";
  749. continue;
  750. }
  751. auto actor_name = graph->ToString() + kSuperKernelActorNameSuffix;
  752. auto super_kernel_actor =
  753. std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, debug_aid_, nullptr);
  754. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  755. InsertActor(super_kernel_actor.get());
  756. (void)super_kernel_actors.emplace_back(super_kernel_actor);
  757. }
  758. return super_kernel_actors;
  759. }
  760. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
  761. auto actor_set = Fetch(graph_compiler_info.name_);
  762. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  763. return nullptr;
  764. }
  765. auto loop_count = ConfigManager::GetInstance().iter_num();
  766. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  767. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  768. loop_count = 1;
  769. }
  770. auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
  771. auto loop_count_actor =
  772. std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_,
  773. graph_compiler_info.strategy_, graph_compiler_info.device_contexts_);
  774. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  775. MS_EXCEPTION_IF_NULL(loop_count_actor);
  776. InsertActor(loop_count_actor.get());
  777. return loop_count_actor;
  778. }
  779. OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
  780. auto actor_set = Fetch(graph_compiler_info.name_);
  781. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  782. return nullptr;
  783. }
  784. auto loop_count = ConfigManager::GetInstance().iter_num();
  785. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  786. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  787. loop_count = 1;
  788. }
  789. auto actor_name = graph_compiler_info.name_ + kOutputActorNameSuffix;
  790. auto output_actor = std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_);
  791. MS_LOG(INFO) << "Create output actor: " << actor_name;
  792. MS_EXCEPTION_IF_NULL(output_actor);
  793. InsertActor(output_actor.get());
  794. return output_actor;
  795. }
  796. DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  797. const std::vector<DataSourceActorPtr> &data_source_actors,
  798. const HostTensorQueuePtr &host_queue) {
  799. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  800. auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
  801. return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
  802. });
  803. if (iter != data_source_actors.end()) {
  804. host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
  805. }
  806. auto actor_name = graph_compiler_info.name_ + kDataPrepareActorNameSuffix;
  807. auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
  808. &graph_compiler_info, host_queue_ds_actor, host_queue);
  809. MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
  810. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  811. // Cache the nodes which need continuous memory.
  812. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  813. for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
  814. const auto &graph = graph_compiler_info.graphs_[index];
  815. MS_EXCEPTION_IF_NULL(graph);
  816. if (graph->is_executing_sink()) {
  817. continue;
  818. }
  819. auto &execution_order = graph->execution_order();
  820. for (auto &kernel : execution_order) {
  821. if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
  822. continue;
  823. }
  824. auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
  825. auto value = std::make_pair(false, false);
  826. if (common::AnfAlgo::GetInputTensorNum(kernel) > 1) {
  827. value.first = true;
  828. }
  829. if (common::AnfAlgo::GetOutputTensorNum(kernel) > 1) {
  830. value.second = true;
  831. }
  832. if ((value.first == true) || (value.second == true)) {
  833. data_prepare_actor->continuous_memory_nodes_[key] = value;
  834. }
  835. }
  836. }
  837. }
  838. InsertActor(data_prepare_actor.get());
  839. return data_prepare_actor;
  840. }
  841. std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
  842. GraphExecutionStrategy strategy) {
  843. MS_EXCEPTION_IF_NULL(actor_set);
  844. std::vector<AbstractActorPtr> no_input_kernel_actors;
  845. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  846. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  847. if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) {
  848. (void)no_input_kernel_actors.emplace_back(super_kernel_actor);
  849. }
  850. }
  851. for (auto &kernel_actor : actor_set->kernel_actors_) {
  852. MS_EXCEPTION_IF_NULL(kernel_actor);
  853. // Framework will trigger kernel actor running in the step execution strategy.
  854. if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
  855. kernel_actor->input_controls_num_++;
  856. continue;
  857. }
  858. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  859. (void)no_input_kernel_actors.emplace_back(kernel_actor);
  860. }
  861. }
  862. return no_input_kernel_actors;
  863. }
// Create a SendActor or RecvActor for an rpc kernel and register it in the rpc
// node scheduler. Throws for a kernel that is neither send nor recv. When
// ENABLE_RPC_ACTOR is not defined, returns nullptr unconditionally.
KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
                                                GraphExecutionStrategy strategy,
                                                const std::set<size_t> &ref_input_indexes,
                                                const std::set<size_t> &ref_output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(device_context);
#ifdef ENABLE_RPC_ACTOR
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  if (common::AnfAlgo::GetCNodeName(kernel) == kRpcSendOpName) {
    auto send_actor =
      std::make_shared<SendActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
                                  debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(send_actor);
    rpc_node_scheduler_->InsertSendActor(send_actor);
    return send_actor;
  } else if (common::AnfAlgo::GetCNodeName(kernel) == kRpcRecvOpName) {
    auto recv_actor =
      std::make_shared<RecvActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
                                  debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(recv_actor);
    rpc_node_scheduler_->InsertRecvActor(recv_actor);
    return recv_actor;
  } else {
    MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope() << " is not an rpc kernel.";
  }
#endif
  return nullptr;
}
// Link the input arrows of a sink-mode graph to its super kernel actor: data
// arrows for ordinary inputs, control arrows (via auto monad) for monad
// inputs. Also collects the persistent device tensors reached through auto
// monad kernels into the actor's device tensor store set.
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
                                             std::vector<AbstractActor *> *const auto_monad_actors) {
  MS_EXCEPTION_IF_NULL(graph);
  // The data arrow linking is taken over by the control flow.
  if (graph_compiler_info.control_node_parser_ != nullptr &&
      graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, nullptr)) {
    return;
  }
  auto to_actor_name = graph->ToString() + kSuperKernelActorNameSuffix;
  auto to_actor = FetchActor(to_actor_name);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto &input_nodes = graph->input_nodes();
  for (size_t node_index = 0; node_index < input_nodes.size(); ++node_index) {
    auto &input_node = input_nodes[node_index];
    MS_EXCEPTION_IF_NULL(input_node);
    if (HasAbstractMonad(input_node)) {
      MS_LOG(INFO) << "The graph:" << graph->graph_id()
                   << " has abstract monad input node:" << input_node->DebugString() << ", input index:" << node_index;
      LinkControlArrowByAutoMonad(to_actor, input_node, graph);
      continue;  // No data arrow for monad input.
    }
    UpdateRefCount(input_node, 0, true);
    KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0);
    KernelWithIndex to_kernel_with_input_idx = std::make_pair(input_node, node_index);
    // The gather of linking data arrows of kernel by the different from kernel type.
    LinkDataArrow(to_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
  }
  std::vector<CNodePtr> auto_monad_kernels;
  // Foreach the execution order to get the auto monad kernels.
  auto &execution_order = graph->execution_order();
  (void)std::for_each(execution_order.begin(), execution_order.end(), [&](const CNodePtr &kernel) {
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
      if (HasAbstractMonad(input_node)) {
        // Note: a kernel with several monad inputs is appended once per monad input;
        // duplicates are harmless since the store below is a set.
        (void)auto_monad_kernels.emplace_back(kernel);
        continue;
      }
    }
  });
  // Foreach auto monad kernels to get the auto monad device tensor stores.
  (void)std::for_each(auto_monad_kernels.begin(), auto_monad_kernels.end(), [&](const CNodePtr &kernel) {
    for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(kernel); ++i) {
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, i, false);
      auto front_node = FetchFrontNodeByBackendNode(from_kernel_with_output_idx.first, graph);
      if (IsPersistentDeviceTensor(front_node)) {
        (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
      }
    }
  });
  if (to_actor->auto_monad_device_tensor_stores_.size() > 0) {
    (void)auto_monad_actors->emplace_back(to_actor);
  }
}
// Link the input arrows of a non-sink graph kernel by kernel: data arrows for
// ordinary inputs, control arrows (via auto monad) for monad-related inputs.
// Communication kernels are collected for later ordered control linking, and
// send/recv based control arrows are added at the end.
void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
                                                const GraphCompilerInfo &graph_compiler_info,
                                                std::vector<AbstractActor *> *const auto_monad_actors,
                                                std::vector<CNodePtr> *const communication_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(auto_monad_actors);
  MS_EXCEPTION_IF_NULL(communication_nodes);
  // Inputs wrapped by these primitives participate in auto monad control linking.
  const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
  auto &execution_order = graph->execution_order();
  // Foreach the execution order to link the actors.
  for (const auto &kernel : execution_order) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsCommunicationOp(kernel)) {
      (void)communication_nodes->emplace_back(kernel);
    }
    if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
      continue;
    }
    const auto &kernel_actor = FetchActor(kernel->fullname_with_scope());
    MS_EXCEPTION_IF_NULL(kernel_actor);
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
      // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
      if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
        LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_);
      }
      if (HasAbstractMonad(input_node)) {
        (void)auto_monad_actors->emplace_back(kernel_actor);
        continue;  // No data arrow for monad input.
      }
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
      KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
      // The data arrow linking is taken over by the control flow.
      if (graph_compiler_info.control_node_parser_ != nullptr &&
          graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, from_kernel_with_output_idx.first)) {
        continue;
      }
      // The gather of linking data arrows of kernel by the different from kernel type.
      LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
    }
  }
  // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  LinkControlArrowBySendRecvNodes(graph);
}
  990. void GraphScheduler::LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  991. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  992. const KernelWithIndex &to_kernel_with_input_idx) {
  993. MS_EXCEPTION_IF_NULL(to_actor);
  994. MS_EXCEPTION_IF_NULL(graph);
  995. auto from_kernel = from_kernel_with_output_idx.first;
  996. MS_EXCEPTION_IF_NULL(from_kernel);
  997. auto kernel_type = FetchKernelTransformType(from_kernel, graph, graph_compiler_info.origin_parameters_order_,
  998. graph_compiler_info.strategy_);
  999. auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, from_kernel, graph);
  1000. if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
  1001. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  1002. MS_LOG(WARNING) << "Invalid from node:" << from_kernel->fullname_with_scope() << ", type:" << kernel_type;
  1003. }
  1004. return;
  1005. }
  1006. (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
  1007. to_kernel_with_input_idx, graph);
  1008. }
  1009. void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor,
  1010. const KernelWithIndex &from_kernel_with_output_idx,
  1011. const KernelWithIndex &to_kernel_with_input_idx,
  1012. const KernelGraphPtr &graph) {
  1013. MS_EXCEPTION_IF_NULL(to_actor);
  1014. MS_EXCEPTION_IF_NULL(graph);
  1015. auto from_kernel = from_kernel_with_output_idx.first;
  1016. MS_EXCEPTION_IF_NULL(from_kernel);
  1017. auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
  1018. (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
  1019. }
// Link the data arrow when the from-node is an internal parameter: resolve the
// internal parameter to its front node, then to the real producing actor via
// graph_output_to_actor_ (or to a device tensor store for persistent tensors),
// and re-dispatch the link to the function of the real producer's kernel type.
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor,
                                                       const KernelWithIndex &from_kernel_with_output_idx,
                                                       const KernelWithIndex &to_kernel_with_input_idx,
                                                       const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  auto internal_parameter = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(internal_parameter);
  // Parameter ---> front node.
  auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
  auto front_output_node = front_output_with_index.first;
  MS_EXCEPTION_IF_NULL(front_output_node);
  // Switch actors are linked by the control flow scheduler, not here.
  if (IsSwitchActor(front_output_node)) {
    return;
  }
  auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  AbstractActor *real_from_actor = nullptr;
  KernelTransformType kernel_type;
  if (IsPersistentDeviceTensor(front_output_node)) {
    kernel_type = KernelTransformType::kDeviceTensorStore;
  } else {
    // front node ---> actor.
    if (graph_output_to_actor_.count(front_output_with_index) == 0) {
      MS_LOG(EXCEPTION) << "Can't find actor by front node:" << common::AnfAlgo::GetNodeDebugString(front_output_node)
                        << ", internal parameter:" << common::AnfAlgo::GetNodeDebugString(internal_parameter);
    }
    auto actor_pair = graph_output_to_actor_[front_output_with_index];
    MS_EXCEPTION_IF_NULL(actor_pair.first);
    MS_EXCEPTION_IF_NULL(actor_pair.second.first);
    MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
                 << ", corresponding front node:" << front_output_node->fullname_with_scope()
                 << " with index:" << front_output_with_index.second
                 << ", from actor:" << actor_pair.first->GetAID().Name()
                 << " node:" << actor_pair.second.first->fullname_with_scope()
                 << " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name()
                 << " with index:" << to_kernel_with_input_idx.second;
    real_from_actor = actor_pair.first;
    real_from_kernel_with_output_idx = actor_pair.second;
    kernel_type = actor_pair.first->type_;
  }
  // Update the real input node.
  MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
  if (to_kernel_with_input_idx.first->isa<CNode>()) {
    auto kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first->cast<CNodePtr>());
    MS_EXCEPTION_IF_NULL(kernel_mod);
    kernel_mod->InsertRealInputNode(real_from_kernel_with_output_idx.first, real_from_kernel_with_output_idx.second,
                                    to_kernel_with_input_idx.second);
  }
  if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
    MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
  }
  // Re-dispatch to the link function of the real producer's kernel type.
  (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
                                              to_kernel_with_input_idx, graph);
}
  1074. void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1075. const KernelWithIndex &from_kernel_with_output_idx,
  1076. const KernelWithIndex &to_kernel_with_input_idx,
  1077. const KernelGraphPtr &) {
  1078. MS_EXCEPTION_IF_NULL(from_actor);
  1079. MS_EXCEPTION_IF_NULL(to_actor);
  1080. auto from_kernel = from_kernel_with_output_idx.first;
  1081. MS_EXCEPTION_IF_NULL(from_kernel);
  1082. auto from_output_index = from_kernel_with_output_idx.second;
  1083. auto to_input_index = to_kernel_with_input_idx.second;
  1084. // Get the position of from kernel in the data source actor.
  1085. auto position = from_actor->FetchNodePosition(from_kernel);
  1086. if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
  1087. MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  1088. }
  1089. if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
  1090. LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
  1091. } else {
  1092. AddDataArrow(from_actor, to_actor, from_kernel, from_output_index, to_input_index);
  1093. }
  1094. }
  1095. void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1096. const KernelWithIndex &from_kernel_with_output_idx,
  1097. const KernelWithIndex &to_kernel_with_input_idx,
  1098. const KernelGraphPtr &graph) {
  1099. auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  1100. MS_EXCEPTION_IF_NULL(host_ds_actor);
  1101. MS_EXCEPTION_IF_NULL(from_kernel_with_output_idx.first);
  1102. KernelWithIndex real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1103. // Get the position and real kernel by from kernel in the data source actor.
  1104. auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
  1105. real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
  1106. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx, graph);
  1107. }
  1108. void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1109. const KernelWithIndex &from_kernel_with_output_idx,
  1110. const KernelWithIndex &to_kernel_with_input_idx,
  1111. const KernelGraphPtr &graph) {
  1112. auto real_from_actor = from_actor;
  1113. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1114. auto from_kernel = from_kernel_with_output_idx.first;
  1115. // Update the from kernel info by the real node info.
  1116. MS_EXCEPTION_IF_NULL(from_kernel);
  1117. if (IsSkippedKernelActor(from_kernel)) {
  1118. real_from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(from_kernel, 0, false);
  1119. MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
  1120. LinkControlArrowBySkippedNode(to_actor, from_kernel);
  1121. MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
  1122. MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
  1123. << to_kernel_with_input_idx.first->fullname_with_scope()
  1124. << ", aggregate input index: " << to_kernel_with_input_idx.second
  1125. << ", skip node: " << from_kernel->fullname_with_scope()
  1126. << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
  1127. real_from_actor = FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope());
  1128. MS_EXCEPTION_IF_NULL(real_from_actor);
  1129. }
  1130. LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx,
  1131. graph);
  1132. }
// Insert (or reuse) a copy actor between two actors that live on different devices. The copy actor
// is identified by a name derived from the from actor/kernel/output index, so several consumers of
// the same output share one copy actor: the from->copy arrow is added only on first creation, the
// copy->to arrow is added on every call.
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
                                               const KernelWithIndex &from_kernel_with_output_idx,
                                               const KernelWithIndex &to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  // The name uniquely keys the copy actor per (from actor, kernel, output index).
  std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
                     "_output_index:" + std::to_string(from_kernel_with_output_idx.second);
  CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  // Link between from actor and copy actor.
  if (copy_actor == nullptr) {
    // Create the copy actor.
    auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
    (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
    copy_actor = copy_actor_shared_ptr.get();
    MS_EXCEPTION_IF_NULL(copy_actor);
    InsertActor(copy_actor);

    // Set the member device_contexts_ of the copy actor: index 0 is the source device, index 1 the
    // destination device.
    auto position = from_actor->FetchNodePosition(from_kernel);
    if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
      MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
    }
    auto from_device_context = from_actor->device_contexts_[position];
    auto to_device_context = to_actor->device_contexts_[0];
    MS_EXCEPTION_IF_NULL(from_device_context);
    MS_EXCEPTION_IF_NULL(to_device_context);
    (void)copy_actor->device_contexts_.emplace_back(from_device_context);
    (void)copy_actor->device_contexts_.emplace_back(to_device_context);

    // Set the member output_ of the copy actor. A super kernel actor consumes the node's own
    // output address; any other actor consumes the address of the to kernel's input.
    if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
      copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false);
    } else {
      copy_actor->output_ =
        AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false);
    }
    MS_EXCEPTION_IF_NULL(copy_actor->output_);
    if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) {
      MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType()
                        << ", to device context type:" << to_device_context->GetDeviceAddressType();
    }

    // Link between from actor and copy actor.
    AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0);
  }

  // If the copy actor already exists, only need link between copy actor and to actor.
  AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second);
  // For a super kernel consumer the output lives as long as the graph runs, so the ref count is
  // made persistent; otherwise just bump it for this extra consumer.
  if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
    UpdateRefCount(copy_actor->output_.get(), true);
  } else {
    UpdateRefCount(copy_actor->output_.get(), false);
  }
}
// Link control arrows for the execution-order dependencies expressed by auto monad nodes
// (Depend/UpdateState/Load) and MakeTuple. Recursively expands those wrapper nodes until it
// reaches real kernels, then adds a control arrow from each real dependency's actor to to_actor.
void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
                                                 const KernelGraphPtr &graph, const ControlNodeParserPtr &parser) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_node);
  MS_EXCEPTION_IF_NULL(graph);
  // Find the real input node, include the monad node and make tuple node.
  const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
                                                  prim::kPrimMakeTuple};
  const auto &input_kernel_with_output_idx =
    common::AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
  MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  auto input_anfnode = input_kernel_with_output_idx.first;
  CNodePtr input_cnode = nullptr;
  if (input_anfnode->isa<CNode>()) {
    input_cnode = input_anfnode->cast<CNodePtr>();
  }
  // Make tuple node needs to be expanded: every element is an independent dependency.
  if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
      LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser);
    }
    return;
  }
  const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  // Get the real depend input by monad node which needs to link the control arrow.
  std::vector<AnfNodePtr> real_depend_inputs;
  if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
      common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
    // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
    // node in this scene.
    if (IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
      real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
    }
  } else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    // UpdateState: every real input (from kUpdateStateRealInput on) is a dependency.
    for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
      real_depend_inputs.push_back(input_cnode->input(i));
    }
  } else {
    // Already a real node: depend on it directly.
    real_depend_inputs.push_back(input_anfnode);
  }
  for (const auto &real_depend_input : real_depend_inputs) {
    auto real_depend_input_with_idx =
      common::AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
    MS_EXCEPTION_IF_NULL(real_depend_input_with_idx.first);
    auto real_depend_kernel = real_depend_input_with_idx.first;
    // Update the real depend kernel in the subgraphs connecting scene: an internal parameter maps
    // to an output node (and actor) of another graph.
    if (IsInternalParameter(real_depend_kernel, graph)) {
      auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
      MS_EXCEPTION_IF_NULL(front_output_with_index.first);
      if (IsTakenOverByControlFlow(front_output_with_index.first, graph, parser)) {
        MS_LOG(DEBUG) << "Skip in control flow from node:" << front_output_with_index.first->DebugString()
                      << " is not in the graph:" << graph->ToString();
        continue;
      }
      if (graph_output_to_actor_.count(front_output_with_index) == 0) {
        MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
      }
      real_depend_kernel = graph_output_to_actor_[front_output_with_index].second.first;
      MS_EXCEPTION_IF_NULL(real_depend_kernel);
      MS_LOG(INFO) << "The graph " << graph->graph_id() << " link control arrow by auto monad from internal parameter: "
                   << real_depend_input_with_idx.first->DebugString()
                   << ", front output node: " << front_output_with_index.first->fullname_with_scope()
                   << ", backend output node: " << real_depend_kernel->fullname_with_scope();
      auto from_actor = graph_output_to_actor_[front_output_with_index].first;
      // If the mapped output already has an actor, link directly and stop; otherwise fall through
      // and resolve the actor from the backend kernel below.
      if (from_actor != nullptr) {
        MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
                     << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
        AddControlArrow(from_actor, to_actor);
        continue;
      }
    }
    // The monad node and make tuple node need recursion.
    if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
      LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser);
      continue;
    }
    auto type = FetchKernelTransformType(real_depend_kernel, nullptr);
    auto from_actor = FetchActor(type, "", real_depend_kernel);
    // Not every dependency has an actor (e.g. a parameter); in that case there is nothing to link.
    if (from_actor == nullptr) {
      MS_LOG(DEBUG) << "Link control arrow by auto monad from depend node:" << real_depend_kernel->fullname_with_scope()
                    << " is not actor for the graph: " << graph->graph_id();
      continue;
    }
    MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
                 << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
    AddControlArrow(from_actor, to_actor);
  }
}
  1278. void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) {
  1279. MS_EXCEPTION_IF_NULL(to_actor);
  1280. MS_EXCEPTION_IF_NULL(skipped_node);
  1281. // Link the control arrow from all the inputs of skipped node to the user of skipped node.
  1282. auto input_num = common::AnfAlgo::GetInputTensorNum(skipped_node);
  1283. for (size_t i = 0; i < input_num; ++i) {
  1284. auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
  1285. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1286. auto from_actor = FetchActor(kernel_with_index.first->fullname_with_scope());
  1287. MS_EXCEPTION_IF_NULL(from_actor);
  1288. MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
  1289. << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
  1290. AddControlArrow(from_actor, to_actor);
  1291. }
  1292. }
// Link the control arrows that order allreduce kernels with their paired send/recv kernels so the
// communication can overlap with computation on another stream while keeping data dependencies.
void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Pairs where the send/recv precede an allreduce: inputs -> send -> recv -> allreduce.
  for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
    auto to_allreduce_node = from_iter.first;
    auto from_send_node = from_iter.second.first;
    auto from_recv_node = from_iter.second.second;
    MS_EXCEPTION_IF_NULL(to_allreduce_node);
    MS_EXCEPTION_IF_NULL(from_send_node);
    MS_EXCEPTION_IF_NULL(from_recv_node);
    MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
    auto to_allreduce_actor = FetchActor(to_allreduce_node->fullname_with_scope());
    auto from_send_actor = FetchActor(from_send_node->fullname_with_scope());
    auto from_recv_actor = FetchActor(from_recv_node->fullname_with_scope());
    MS_EXCEPTION_IF_NULL(to_allreduce_actor);
    MS_EXCEPTION_IF_NULL(from_send_actor);
    MS_EXCEPTION_IF_NULL(from_recv_actor);

    // inputs of to_allreduce_actor --> from_send_actor
    for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
      // Only kernel actors participate; other input kinds are skipped (dynamic_cast yields null).
      auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
      if (input_actor != nullptr) {
        AddControlArrow(input_actor, from_send_actor);
      }
    }
    // from_send_actor --> from_recv_actor
    AddControlArrow(from_send_actor, from_recv_actor);
    // from_recv_actor --> to_allreduce_actor
    AddControlArrow(from_recv_actor, to_allreduce_actor);
  }

  // Pairs where the send/recv follow an allreduce: allreduce -> send -> recv -> allreduce users.
  for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
    auto from_allreduce_node = to_iter.first;
    auto to_send_node = to_iter.second.first;
    auto to_recv_node = to_iter.second.second;
    MS_EXCEPTION_IF_NULL(from_allreduce_node);
    MS_EXCEPTION_IF_NULL(to_send_node);
    MS_EXCEPTION_IF_NULL(to_recv_node);
    MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
    auto from_allreduce_actor = FetchActor(from_allreduce_node->fullname_with_scope());
    auto to_send_actor = FetchActor(to_send_node->fullname_with_scope());
    auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(from_allreduce_actor);
    MS_EXCEPTION_IF_NULL(to_send_actor);
    MS_EXCEPTION_IF_NULL(to_recv_actor);

    // from_allreduce_actor --> to_send_actor
    AddControlArrow(from_allreduce_actor, to_send_actor);
    // to_send_actor --> to_recv_actor
    AddControlArrow(to_send_actor, to_recv_actor);
    // to_recv_actor --> outputs of from_allreduce_actor
    for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
      auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
      if (output_actor != nullptr) {
        AddControlArrow(to_recv_actor, output_actor);
      }
    }

    // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
    // reused only when the recv node runs finished, which is expressed by the reference count increased.
    for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
      auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      UpdateRefCount(device_tensor.get());
      (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
    }
  }
}
  1356. void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set,
  1357. const GroupNameToCommuNodes &communication_node_groups,
  1358. const std::vector<AbstractActor *> &auto_monad_actors,
  1359. const GraphCompilerInfo &graph_compiler_info) {
  1360. MS_EXCEPTION_IF_NULL(actor_set);
  1361. for (const auto &communication_nodes : communication_node_groups) {
  1362. // Link the control arrows by the communication nodes to ensure communication nodes running order.
  1363. LinkControlArrowByCommunicationNode(communication_nodes.second, graph_compiler_info);
  1364. }
  1365. // Auto monad actor may modify the device tensor store.
  1366. LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
  1367. // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  1368. actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
  1369. // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
  1370. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
  1371. LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
  1372. graph_compiler_info.control_node_parser_);
  1373. }
  1374. // Link control arrows for custom actor.
  1375. LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
  1376. LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
  1377. graph_compiler_info.control_node_parser_);
  1378. LinkControlArrowForOutputActor(actor_set->output_actor_.get(), actor_set);
  1379. }
// Link arrows for custom actors: (1) control arrows for Depend edges that touch a custom actor
// node; (2) for custom actors with no inputs, data arrows from their dynamic-shape depends and a
// control arrow from the data prepare actor so they still get triggered.
void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
                                                    const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  // Link depend(custom, custom) or depend(custom, kernel) or depend(internal parameter, custom).
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Sink graphs run whole-graph; no per-node actors to link.
    if (graph->is_executing_sink()) {
      continue;
    }
    auto all_nodes = TopoSort(graph->get_return());
    for (const auto &node : all_nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
        continue;
      }
      auto depend_cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend_cnode);
      // Depend semantics: the attach node must run before the real input node.
      auto from_node = depend_cnode->input(kDependAttachNodeIndex);
      auto to_node = depend_cnode->input(kRealInputIndexInDepend);
      MS_EXCEPTION_IF_NULL(from_node);
      MS_EXCEPTION_IF_NULL(to_node);
      // Only Depend edges involving at least one custom actor node are handled here.
      if (!AnfUtils::IsCustomActorNode(from_node) && !AnfUtils::IsCustomActorNode(to_node)) {
        continue;
      }
      AbstractActor *from_actor = nullptr;
      // InternalParameter --> CustomActor.
      if (IsInternalParameter(from_node, graph)) {
        auto front_output_with_index = graph->GetFrontNodeByInternalParameter(from_node);
        auto front_output_node = front_output_with_index.first;
        MS_EXCEPTION_IF_NULL(front_output_node);
        if (IsSwitchActor(front_output_node) || (graph_output_to_actor_.count(front_output_with_index) == 0)) {
          continue;
        }
        from_actor = graph_output_to_actor_[front_output_with_index].first;
      } else {
        auto from_kernel_type = FetchKernelTransformType(from_node, graph, graph_compiler_info.origin_parameters_order_,
                                                         graph_compiler_info.strategy_);
        from_actor = FetchActor(from_kernel_type, graph_compiler_info.name_, from_node, graph);
      }
      MS_EXCEPTION_IF_NULL(from_actor);
      auto to_kernel_type = FetchKernelTransformType(to_node, graph, graph_compiler_info.origin_parameters_order_,
                                                     graph_compiler_info.strategy_);
      auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_node, graph);
      MS_EXCEPTION_IF_NULL(to_actor);
      AddControlArrow(from_actor, to_actor);
    }
  }

  // In control flow, no input actors should be linked to entrance actors.
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  if (parser->IsInited()) {
    return;
  }

  // Handle the no input custom actor.
  for (const auto &custom_actor : actor_set->custom_actors_) {
    MS_EXCEPTION_IF_NULL(custom_actor);
    if (custom_actor->input_controls_num_ > 0) {
      continue;
    }
    auto kernel = custom_actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
    MS_EXCEPTION_IF_NULL(base_node);
    // Link data arrows from the inputs the dynamic-shape infer depends on.
    auto dynamic_shape_depends = abstract::GetDependsFormMap(base_node);
    for (auto iter = dynamic_shape_depends.begin(); iter != dynamic_shape_depends.end(); ++iter) {
      auto input_node = common::AnfAlgo::GetInputNode(base_node, *iter);
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
      auto graph = FetchKernelGraph(from_kernel_with_output_idx.first);
      auto kernel_type =
        FetchKernelTransformType(from_kernel_with_output_idx.first, graph, graph_compiler_info.origin_parameters_order_,
                                 graph_compiler_info.strategy_);
      auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, from_kernel_with_output_idx.first, graph);
      // The input_node maybe a data(Tensor) and the from_actor is nullptr
      if (from_actor != nullptr) {
        AddDataArrow(from_actor, custom_actor.get(), from_kernel_with_output_idx.first,
                     from_kernel_with_output_idx.second, *iter);
      }
    }
    // The data prepare actor triggers the otherwise input-less custom actor.
    AddControlArrow(actor_set->data_prepare_actor_.get(), custom_actor.get());
  }
}
  1462. void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  1463. const GraphCompilerInfo &graph_compiler_info) {
  1464. const size_t kCommunicationNodesMinNum = 2;
  1465. if (communication_nodes.size() < kCommunicationNodesMinNum) {
  1466. return;
  1467. }
  1468. // Ensure communication node to execute orderly.
  1469. for (size_t i = 1; i < communication_nodes.size(); ++i) {
  1470. auto from_actor = FetchActor(communication_nodes[i - 1]->fullname_with_scope());
  1471. auto to_actor = FetchActor(communication_nodes[i]->fullname_with_scope());
  1472. MS_EXCEPTION_IF_NULL(from_actor);
  1473. MS_EXCEPTION_IF_NULL(to_actor);
  1474. AddControlArrow(from_actor, to_actor);
  1475. }
  1476. // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
  1477. // Using the multi stream to optimize the performance in the future.
  1478. for (auto &graph : graph_compiler_info.graphs_) {
  1479. MS_EXCEPTION_IF_NULL(graph);
  1480. auto &execution_order = graph->execution_order();
  1481. for (size_t i = 1; i < execution_order.size(); ++i) {
  1482. auto from_actor = FetchActor(execution_order[i - 1]->fullname_with_scope());
  1483. auto to_actor = FetchActor(execution_order[i]->fullname_with_scope());
  1484. if ((from_actor != nullptr) && (to_actor != nullptr)) {
  1485. AddControlArrow(from_actor, to_actor);
  1486. }
  1487. }
  1488. }
  1489. }
  1490. void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
  1491. const ActorSet *actor_set,
  1492. const ControlNodeParserPtr &parser) {
  1493. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  1494. MS_EXCEPTION_IF_NULL(actor_set);
  1495. MS_EXCEPTION_IF_NULL(parser);
  1496. // Data prepare actor --> data source actor.
  1497. for (auto &data_source_actor : actor_set->data_source_actors_) {
  1498. MS_EXCEPTION_IF_NULL(data_source_actor);
  1499. AddControlArrow(data_prepare_actor, data_source_actor.get());
  1500. }
  1501. // In control flow, control arrow of no input kernel actor needs to be connected to the corresponding entrance actor.
  1502. if (!parser->IsInited()) {
  1503. // Data prepare actor --> no input kernel actor.
  1504. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  1505. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  1506. AddControlArrow(data_prepare_actor, no_input_kernel_actor.get());
  1507. }
  1508. }
  1509. // Data prepare actor --> loop count actor.
  1510. if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
  1511. (actor_set->loop_count_actor_ != nullptr)) {
  1512. AddControlArrow(data_prepare_actor, actor_set->loop_count_actor_.get());
  1513. }
  1514. }
  1515. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  1516. const ControlNodeParserPtr &parser) {
  1517. MS_EXCEPTION_IF_NULL(actor_set);
  1518. MS_EXCEPTION_IF_NULL(parser);
  1519. // There is no loop count actor in step mode.
  1520. if (loop_count_actor == nullptr) {
  1521. return;
  1522. }
  1523. // Collect the actors which have no output.
  1524. std::vector<AbstractActor *> no_output_actors;
  1525. for (auto &super_actor : actor_set->super_kernel_actors_) {
  1526. if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
  1527. (void)no_output_actors.emplace_back(super_actor.get());
  1528. }
  1529. }
  1530. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1531. // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
  1532. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
  1533. (void)no_output_actors.emplace_back(kernel_actor.get());
  1534. }
  1535. }
  1536. for (auto &data_actor : actor_set->data_source_actors_) {
  1537. if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
  1538. (void)no_output_actors.emplace_back(data_actor.get());
  1539. }
  1540. }
  1541. for (auto &copy_actor : copy_actors_) {
  1542. if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
  1543. (void)no_output_actors.emplace_back(copy_actor.get());
  1544. }
  1545. }
  1546. for (auto &custom_actor : actor_set->custom_actors_) {
  1547. if ((custom_actor->output_data_arrows_.size() == 0) && (custom_actor->output_control_arrows_.size() == 0)) {
  1548. (void)no_output_actors.emplace_back(custom_actor.get());
  1549. }
  1550. }
  1551. // No output actor --> loop count actor.
  1552. // In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
  1553. if (!parser->IsInited()) {
  1554. for (auto &no_output_actor : no_output_actors) {
  1555. AddControlArrow(no_output_actor, loop_count_actor);
  1556. }
  1557. }
  1558. // Loop count actor --> output actor.
  1559. AddControlArrow(loop_count_actor, actor_set->output_actor_.get());
  1560. // Loop count actor --> data prepare actor.
  1561. MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  1562. loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
  1563. actor_set->data_prepare_actor_->input_controls_num_++;
  1564. (void)actor_set->data_prepare_actor_->input_control_arrow_aids_.emplace_back(loop_count_actor->GetAID());
  1565. }
  1566. void GraphScheduler::LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set) {
  1567. MS_EXCEPTION_IF_NULL(actor_set);
  1568. // There is no output actor in step mode.
  1569. if (output_actor == nullptr) {
  1570. return;
  1571. }
  1572. // Output actor --> data prepare actor.
  1573. // The output actor needs to free the output memory in the running and needs this control arrow.
  1574. AddControlArrow(output_actor, actor_set->data_prepare_actor_.get());
  1575. }
  1576. void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
  1577. const GraphCompilerInfo &graph_compiler_info) {
  1578. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
  1579. (graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
  1580. // In control flow, the exit actor of the root graph sends output data to the output actor.
  1581. return;
  1582. }
  1583. MS_EXCEPTION_IF_NULL(to_actor);
  1584. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1585. const auto &graph = graph_compiler_info.graphs_[i];
  1586. MS_EXCEPTION_IF_NULL(graph);
  1587. auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  1588. std::set<std::vector<size_t>> unique_output_positions;
  1589. std::set<KernelWithIndex> unique_outputs;
  1590. for (const auto &output : outputs) {
  1591. if (IsInternalParameter(output.first, graph)) {
  1592. MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
  1593. continue;
  1594. }
  1595. (void)unique_outputs.insert(output);
  1596. }
  1597. for (const auto &output_with_index : unique_outputs) {
  1598. MS_EXCEPTION_IF_NULL(output_with_index.first);
  1599. auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
  1600. const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
  1601. if (iter == graph_compiler_info.origin_outputs_order_.end()) {
  1602. continue;
  1603. }
  1604. // Skip duplicate position.
  1605. if (unique_output_positions.count(iter->second) > 0) {
  1606. continue;
  1607. }
  1608. (void)unique_output_positions.insert(iter->second);
  1609. for (auto &output_position : iter->second) {
  1610. if (output_position >= to_actor->device_contexts_.size()) {
  1611. MS_LOG(EXCEPTION) << "The output position is out of range.";
  1612. }
  1613. to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
  1614. // The graph output is from device tensor store.
  1615. if (IsPersistentDeviceTensor(output_with_index.first)) {
  1616. (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
  1617. if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
  1618. MS_EXCEPTION_IF_NULL(output_with_index.first);
  1619. MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address not exit";
  1620. continue;
  1621. }
  1622. // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
  1623. // as output_idx directly.
  1624. auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
  1625. MS_EXCEPTION_IF_NULL(device_tensor);
  1626. // The output actor need use the relevant information of node to create output tensor.
  1627. device_tensor->SetNodeIndex(output_with_index.first, 0);
  1628. continue;
  1629. }
  1630. // The graph output is from kernel actor or data source actor.
  1631. auto kernel_type = FetchKernelTransformType(
  1632. output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_);
  1633. auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph);
  1634. if (from_actor == nullptr) {
  1635. continue;
  1636. }
  1637. auto real_from_kernel = output_with_index.first;
  1638. // Update the real node in the host data source actor.
  1639. if (kernel_type == KernelTransformType::kHostDataSourceActor) {
  1640. auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  1641. MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
  1642. auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first);
  1643. real_from_kernel = host_queue_ds_actor->FetchNode(position);
  1644. UpdateRefCount(output_with_index.first, output_with_index.second, true);
  1645. }
  1646. AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second, output_position);
  1647. }
  1648. }
  1649. }
  1650. }
  1651. void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors) {
  1652. const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  1653. for (auto &auto_monad_actor : auto_monad_actors) {
  1654. MS_EXCEPTION_IF_NULL(auto_monad_actor);
  1655. for (auto &device_tensor_store_key : auto_monad_actor->device_tensor_store_keys_) {
  1656. auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
  1657. if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
  1658. continue;
  1659. }
  1660. // Find the device tensor store that needs to be processed accurately.
  1661. if ((auto_monad_actor->type_ == KernelTransformType::kSuperKernelActor) &&
  1662. (auto_monad_actor->auto_monad_device_tensor_stores_.find(device_tensor_store_key.second) ==
  1663. auto_monad_actor->auto_monad_device_tensor_stores_.end())) {
  1664. continue;
  1665. }
  1666. // Create the copy actor.
  1667. std::string name = "copy_from:" + auto_monad_actor->GetAID().Name() +
  1668. "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
  1669. if (FetchActor(name) != nullptr) {
  1670. continue;
  1671. }
  1672. auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
  1673. MS_EXCEPTION_IF_NULL(copy_actor);
  1674. (void)copy_actors_.emplace_back(copy_actor);
  1675. InsertActor(copy_actor.get());
  1676. // Set the member of the copy actor.
  1677. (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
  1678. auto input_device_context = auto_monad_actor->device_contexts_[0];
  1679. (void)copy_actor->device_contexts_.emplace_back(input_device_context);
  1680. auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
  1681. ? device_tensors[1]
  1682. : device_tensors[0];
  1683. MS_EXCEPTION_IF_NULL(another_device_tensor);
  1684. auto another_device_type = another_device_tensor->DeviceType();
  1685. const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
  1686. {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
  1687. MS_EXCEPTION_IF_NULL(another_device_context);
  1688. (void)copy_actor->device_contexts_.emplace_back(another_device_context);
  1689. MS_LOG(INFO) << "The auto monad actor: " << auto_monad_actor->GetAID().Name()
  1690. << "has control arrows number:" << auto_monad_actor->output_control_arrows_.size();
  1691. // Link from copy actor to auto monad actor users.
  1692. for (auto &output_contorl : auto_monad_actor->output_control_arrows_) {
  1693. (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
  1694. }
  1695. // Move the control arrows from auto monad actor to auto monad actor users.
  1696. auto_monad_actor->output_control_arrows_.clear();
  1697. // Link from auto monad actor to copy actor.
  1698. AddControlArrow(auto_monad_actor, copy_actor.get());
  1699. }
  1700. }
  1701. }
  1702. void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
  1703. MS_EXCEPTION_IF_NULL(anf_node);
  1704. MS_EXCEPTION_IF_NULL(device_tensor);
  1705. MS_LOG(DEBUG) << "Add device tensor store:" << device_tensor << " for node:" << anf_node->DebugString();
  1706. DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
  1707. UpdateRefCount(device_tensor.get(), true);
  1708. }
  1709. void GraphScheduler::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1710. const AnfNodePtr &from_kernel, size_t from_output_index, size_t to_input_index) {
  1711. MS_EXCEPTION_IF_NULL(from_actor);
  1712. MS_EXCEPTION_IF_NULL(to_actor);
  1713. auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
  1714. (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
  1715. (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
  1716. to_actor->input_datas_num_++;
  1717. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1718. if (from_kernel == nullptr) {
  1719. return;
  1720. }
  1721. // Update the reference count of from_kernel.
  1722. // The device address of super kernel actor can't be changed, so set the max reference count.
  1723. if ((from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
  1724. (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
  1725. UpdateRefCount(from_kernel, from_output_index, true);
  1726. } else {
  1727. UpdateRefCount(from_kernel, from_output_index, false);
  1728. }
  1729. }
  1730. void GraphScheduler::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
  1731. const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
  1732. MS_EXCEPTION_IF_NULL(from_actor);
  1733. MS_EXCEPTION_IF_NULL(to_actor);
  1734. MS_EXCEPTION_IF_NULL(from_kernel);
  1735. auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
  1736. (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
  1737. (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
  1738. to_actor->input_datas_num_++;
  1739. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1740. auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
  1741. MS_EXCEPTION_IF_NULL(device_tensor);
  1742. // The output actor need use the relevant information of node to create output tensor.
  1743. device_tensor->SetNodeIndex(from_kernel, from_output_index);
  1744. // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
  1745. UpdateRefCount(device_tensor.get(), true);
  1746. }
  1747. void GraphScheduler::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
  1748. MS_EXCEPTION_IF_NULL(from_actor);
  1749. MS_EXCEPTION_IF_NULL(to_actor);
  1750. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1751. to_actor->input_controls_num_++;
  1752. (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
  1753. }
  1754. void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
  1755. MS_EXCEPTION_IF_NULL(actor_set);
  1756. auto actors = CollectActors(actor_set);
  1757. for (auto &actor : actors) {
  1758. MS_EXCEPTION_IF_NULL(actor);
  1759. if (actor->type_ >= KernelTransformType::kSwitchActor) {
  1760. continue;
  1761. }
  1762. if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
  1763. (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
  1764. MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name()
  1765. << " is wrong, expect data num: " << actor->input_datas_num_
  1766. << ", actual data num: " << actor->input_data_arrow_aids_.size()
  1767. << ", expect control num: " << actor->input_controls_num_
  1768. << ", actual control num: " << actor->input_control_arrow_aids_.size();
  1769. }
  1770. if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
  1771. (actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
  1772. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
  1773. }
  1774. if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
  1775. (actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
  1776. (actor->input_controls_num_ == 0)) {
  1777. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
  1778. }
  1779. // Check the input of kernel actors and copy actors.
  1780. if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
  1781. size_t expect_toal_input_num = 1;
  1782. if (actor->type_ == KernelTransformType::kKernelActor) {
  1783. auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
  1784. MS_EXCEPTION_IF_NULL(kernel_actor);
  1785. expect_toal_input_num = common::AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
  1786. }
  1787. auto input_data_num = actor->input_datas_num_;
  1788. auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
  1789. if (input_data_num + device_tensor_store_num != expect_toal_input_num) {
  1790. MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name()
  1791. << " is wrong, input data num: " << input_data_num
  1792. << ", device tensor store num: " << device_tensor_store_num
  1793. << ", total input num: " << expect_toal_input_num;
  1794. }
  1795. }
  1796. }
  1797. // Check the output actor.
  1798. auto output_actor = actor_set->output_actor_;
  1799. MS_EXCEPTION_IF_NULL(output_actor);
  1800. if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num_) {
  1801. MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
  1802. << output_actor->outputs_num_ << ", the input data arrows num: " << output_actor->input_datas_num_
  1803. << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
  1804. }
  1805. control_node_scheduler_.CheckActorValid(actor_set);
  1806. }
  1807. void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
  1808. const auto &parser = graph_compiler_info.control_node_parser_;
  1809. MS_EXCEPTION_IF_NULL(parser);
  1810. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1811. const auto &graph = graph_compiler_info.graphs_[i];
  1812. const auto &device_context = graph_compiler_info.device_contexts_[i];
  1813. MS_EXCEPTION_IF_NULL(graph);
  1814. MS_EXCEPTION_IF_NULL(device_context);
  1815. for (auto &value_node : graph->graph_value_nodes()) {
  1816. MS_EXCEPTION_IF_NULL(value_node);
  1817. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1818. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  1819. continue;
  1820. }
  1821. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
  1822. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1823. device_tensor->SetNodeIndex(value_node, 0);
  1824. AddDeviceTensorStore(front_node.get(), device_tensor);
  1825. }
  1826. for (auto &input_node : graph->input_nodes()) {
  1827. MS_EXCEPTION_IF_NULL(input_node);
  1828. AnfNodePtr front_node = nullptr;
  1829. if (IsInternalParameter(input_node, graph)) {
  1830. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
  1831. front_node = front_output_with_index.first;
  1832. } else if (IsPersistentDeviceTensor(input_node)) {
  1833. front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1834. }
  1835. // The front node may be value node in the heterogeneous scene, needs to handle.
  1836. if ((front_node == nullptr) ||
  1837. (!front_node->isa<ValueNode>() && !parser->IsRootGraphPersistentDeviceTensor(front_node))) {
  1838. continue;
  1839. }
  1840. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  1841. MS_EXCEPTION_IF_NULL(device_tensor);
  1842. if (IsPersistentDeviceTensor(input_node) || device_tensor->is_ptr_persisted()) {
  1843. device_tensor->SetNodeIndex(input_node, 0);
  1844. AddDeviceTensorStore(front_node.get(), device_tensor);
  1845. }
  1846. // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
  1847. if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
  1848. MS_LOG(WARNING) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
  1849. << ", type:" << device_context->GetDeviceAddressType();
  1850. auto other_type_device_tensor =
  1851. device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
  1852. device_tensor->type_id(), device_tensor->host_shape());
  1853. other_type_device_tensor->SetNodeIndex(input_node, 0);
  1854. other_type_device_tensor->set_from_persistent_mem(input_node->isa<Parameter>());
  1855. AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
  1856. }
  1857. }
  1858. }
  1859. PersistDeviceTensorForRootGraphControlNode(graph_compiler_info);
  1860. }
  1861. void GraphScheduler::PersistDeviceTensorForRootGraphControlNode(const GraphCompilerInfo &graph_compiler_info) {
  1862. const auto &parser = graph_compiler_info.control_node_parser_;
  1863. if (parser == nullptr || (!parser->IsInited())) {
  1864. return;
  1865. }
  1866. for (auto &root_graph_parameter : graph_compiler_info.origin_parameters_order_) {
  1867. MS_EXCEPTION_IF_NULL(root_graph_parameter);
  1868. if (!IsPersistentDeviceTensor(root_graph_parameter)) {
  1869. continue;
  1870. }
  1871. // The device tensor store has been done in the backend kernel graph corresponding to the root graph.
  1872. if (!DeviceTensorStore::GetInstance().Fetch(root_graph_parameter.get()).empty()) {
  1873. continue;
  1874. }
  1875. // The different root graph parameters may correspond to parameter of same sub kernel graph when call the same sub
  1876. // graph using the different root graph parameters. So can not use the device tensor of sub kernel graph parameter
  1877. // directly and choose the first backend parameter in sub kernel graphs to create new device tensor to make sure
  1878. // that the device tensor of root graph parameters are different.
  1879. const auto &backend_parameter_with_context =
  1880. parser->FetchBackendParameterWithContextByFrontParameter({root_graph_parameter, 0});
  1881. if (backend_parameter_with_context.first == nullptr) {
  1882. MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << root_graph_parameter->DebugString();
  1883. }
  1884. const auto &backend_node = backend_parameter_with_context.first;
  1885. const auto &device_context = backend_parameter_with_context.second;
  1886. MS_EXCEPTION_IF_NULL(backend_node);
  1887. MS_EXCEPTION_IF_NULL(device_context);
  1888. auto sub_device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  1889. MS_EXCEPTION_IF_NULL(sub_device_tensor);
  1890. auto new_device_tensor =
  1891. device_context->CreateDeviceAddress(nullptr, sub_device_tensor->GetSize(), sub_device_tensor->format(),
  1892. sub_device_tensor->type_id(), sub_device_tensor->host_shape());
  1893. MS_EXCEPTION_IF_NULL(new_device_tensor);
  1894. new_device_tensor->SetNodeIndex(backend_node, 0);
  1895. new_device_tensor->set_is_ptr_persisted(sub_device_tensor->is_ptr_persisted());
  1896. new_device_tensor->set_from_persistent_mem(true);
  1897. AddDeviceTensorStore(root_graph_parameter.get(), new_device_tensor);
  1898. MS_LOG(INFO) << "Add device tensor store by root graph parameter:" << root_graph_parameter->fullname_with_scope()
  1899. << ", backend node:" << backend_node->DebugString()
  1900. << ", type:" << device_context->GetDeviceAddressType();
  1901. }
  1902. }
  1903. void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
  1904. MS_EXCEPTION_IF_NULL(actor_set);
  1905. const auto &context_ptr = MsContext::GetInstance();
  1906. MS_EXCEPTION_IF_NULL(context_ptr);
  1907. auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1908. if (!save_graphs) {
  1909. return;
  1910. }
  1911. // Get the saved actor set name.
  1912. auto &kernel_graphs = graph_compiler_info.graphs_;
  1913. MS_EXCEPTION_IF_NULL(kernel_graphs.front());
  1914. auto first_graph_id = kernel_graphs.front()->graph_id();
  1915. MS_EXCEPTION_IF_NULL(kernel_graphs.back());
  1916. auto last_graph_id = kernel_graphs.back()->graph_id();
  1917. std::string strategy = (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) ? "pipeline" : "step";
  1918. std::string save_name = "actor_set_" + strategy + "_kernel_graph_" + std::to_string(first_graph_id);
  1919. if (last_graph_id != first_graph_id) {
  1920. save_name = save_name + "-" + std::to_string(last_graph_id);
  1921. }
  1922. std::string filename = GetSaveGraphsPathName(save_name + ".ir");
  1923. std::ofstream ofs(filename);
  1924. if (!ofs.is_open()) {
  1925. MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
  1926. return;
  1927. }
  1928. DumpDeviceTensorStore(graph_compiler_info, ofs);
  1929. DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
  1930. DumpDSActors(actor_set->data_source_actors_, ofs);
  1931. DumpKernelActors(actor_set->kernel_actors_, ofs);
  1932. DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
  1933. // The on input kernel actors are taken over by control actor in the control flow scene.
  1934. if ((graph_compiler_info.control_node_parser_ == nullptr) ||
  1935. (!graph_compiler_info.control_node_parser_->IsInited())) {
  1936. DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
  1937. }
  1938. DumpCopyActors(actor_set->copy_actors_, ofs);
  1939. DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
  1940. DumpOutputActor(actor_set->output_actor_, ofs);
  1941. DumpControlActors(actor_set->control_actors_, ofs);
  1942. DumpCustomActors(actor_set->custom_actors_, ofs);
  1943. }
  1944. void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  1945. ofs << "[Device tensor stores]\n";
  1946. for (const auto &graph : graph_compiler_info.graphs_) {
  1947. MS_EXCEPTION_IF_NULL(graph);
  1948. ofs << "\tgraph_id:" << graph->graph_id() << "\tis_executing_sink:" << graph->is_executing_sink()
  1949. << "\tis_loop_count_sink:" << graph->is_loop_count_sink()
  1950. << "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n";
  1951. for (auto &value_node : graph->graph_value_nodes()) {
  1952. MS_EXCEPTION_IF_NULL(value_node);
  1953. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1954. continue;
  1955. }
  1956. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1957. MS_EXCEPTION_IF_NULL(front_node);
  1958. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1959. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1960. << "\n";
  1961. for (const auto &device_tensor : device_tensors) {
  1962. MS_EXCEPTION_IF_NULL(device_tensor);
  1963. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1964. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1965. << "\tdynamic_ref_count:" << device_tensor->dynamic_ref_count()
  1966. << "\tdevice_type:" << device_tensor->DeviceType()
  1967. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1968. }
  1969. }
  1970. for (auto &input_node : graph->input_nodes()) {
  1971. MS_EXCEPTION_IF_NULL(input_node);
  1972. if (!IsPersistentDeviceTensor(input_node)) {
  1973. continue;
  1974. }
  1975. const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1976. const auto &root_parameters = graph_compiler_info.origin_parameters_order_;
  1977. if (front_node == nullptr ||
  1978. find(root_parameters.begin(), root_parameters.end(), front_node) == root_parameters.end()) {
  1979. continue;
  1980. }
  1981. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1982. MS_EXCEPTION_IF_NULL(front_node);
  1983. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1984. << "\n";
  1985. for (const auto &device_tensor : device_tensors) {
  1986. MS_EXCEPTION_IF_NULL(device_tensor);
  1987. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1988. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1989. << "\tdynamic_ref_count:" << device_tensor->dynamic_ref_count()
  1990. << "\tdevice_type:" << device_tensor->DeviceType()
  1991. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1992. }
  1993. }
  1994. ofs << "\n";
  1995. for (auto &backend_front_map : graph->backend_front_anf_map()) {
  1996. MS_EXCEPTION_IF_NULL(backend_front_map.first);
  1997. MS_EXCEPTION_IF_NULL(backend_front_map.second);
  1998. MS_LOG(DEBUG) << "Graph: " << graph->graph_id()
  1999. << ", backend node: " << backend_front_map.first->fullname_with_scope()
  2000. << ", front node: " << backend_front_map.second->DebugString();
  2001. }
  2002. }
  2003. }
  2004. } // namespace runtime
  2005. } // namespace mindspore