You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 107 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/graph_scheduler/graph_scheduler.h"
  17. #include <queue>
  18. #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
  19. #include "runtime/graph_scheduler/actor/debug_actor.h"
  20. #include "runtime/graph_scheduler/actor/recorder_actor.h"
  21. #include "runtime/hardware/device_context_manager.h"
  22. #include "mindrt/src/actor/actormgr.h"
  23. #include "mindrt/include/async/async.h"
  24. #include "backend/common/session/anf_runtime_algorithm.h"
  25. #include "include/common/utils/anfalgo.h"
  26. #include "backend/common/optimizer/helper.h"
  27. #include "utils/anf_utils.h"
  28. #include "include/common/utils/config_manager.h"
  29. #include "utils/log_adapter.h"
  30. #include "include/common/utils/convert_utils.h"
  31. #include "utils/ms_context.h"
  32. #include "utils/profile.h"
  33. #if !defined(_WIN32) && !defined(_WIN64)
  34. #include "include/common/utils/signal_util.h"
  35. #endif
  36. #ifndef ENABLE_SECURITY
  37. #include "debug/data_dump/dump_json_parser.h"
  38. #endif
  39. #ifdef ENABLE_DUMP_IR
  40. #include "debug/rdr/recorder_manager.h"
  41. #include "debug/rdr/running_data_recorder.h"
  42. #endif
  43. #ifdef ENABLE_DEBUGGER
  44. #include "debug/debugger/debugger.h"
  45. #endif
  46. #include "profiler/device/profiling.h"
  47. #include "debug/common.h"
  48. #include "runtime/recovery/recovery_context.h"
  49. namespace mindspore {
  50. namespace runtime {
  51. using recovery::RecoveryContext;
  52. namespace {
  53. bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
  54. MS_EXCEPTION_IF_NULL(from_device_context);
  55. MS_EXCEPTION_IF_NULL(to_device_context);
  56. if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
  57. return false;
  58. } else {
  59. return true;
  60. }
  61. }
  62. inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
  63. MS_EXCEPTION_IF_NULL(actor_set);
  64. return actor_set->kernel_actors_.size() == 1;
  65. }
  66. // Convert the actors vector by the actor set.
  67. std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
  68. MS_EXCEPTION_IF_NULL(actor_set);
  69. std::vector<AbstractActorPtr> actors;
  70. if (actor_set->data_prepare_actor_ != nullptr) {
  71. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
  72. }
  73. for (auto &data_source_actor : actor_set->data_source_actors_) {
  74. MS_EXCEPTION_IF_NULL(data_source_actor);
  75. (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
  76. }
  77. for (auto &custom_actor : actor_set->custom_actors_) {
  78. MS_EXCEPTION_IF_NULL(custom_actor);
  79. (void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
  80. }
  81. for (auto &kernel_actor : actor_set->kernel_actors_) {
  82. MS_EXCEPTION_IF_NULL(kernel_actor);
  83. (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
  84. }
  85. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  86. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  87. (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
  88. }
  89. for (auto &copy_actor : actor_set->copy_actors_) {
  90. MS_EXCEPTION_IF_NULL(copy_actor);
  91. (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
  92. }
  93. if (actor_set->loop_count_actor_ != nullptr) {
  94. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
  95. }
  96. if (actor_set->output_actor_ != nullptr) {
  97. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
  98. }
  99. if (actor_set->control_actors_ != nullptr) {
  100. const auto &control_actor_set = actor_set->control_actors_;
  101. for (auto &switch_actor : control_actor_set->switch_actors_) {
  102. MS_EXCEPTION_IF_NULL(switch_actor);
  103. (void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
  104. }
  105. for (auto &gather_actor : control_actor_set->gather_actors_) {
  106. MS_EXCEPTION_IF_NULL(gather_actor);
  107. (void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
  108. }
  109. for (auto &entrance_actor : control_actor_set->entrance_actors_) {
  110. MS_EXCEPTION_IF_NULL(entrance_actor);
  111. (void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
  112. }
  113. for (auto &exit_actor : control_actor_set->exit_actors_) {
  114. MS_EXCEPTION_IF_NULL(exit_actor);
  115. (void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
  116. }
  117. for (auto &stack_actor : control_actor_set->stack_actors_) {
  118. MS_EXCEPTION_IF_NULL(stack_actor);
  119. (void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
  120. }
  121. }
  122. return actors;
  123. }
  124. void ClearNodeInfo(const KernelGraphPtr &graph) {
  125. MS_EXCEPTION_IF_NULL(graph);
  126. // Clear input parameter device tensor and device tensor store.
  127. for (const auto &input_node : graph->input_nodes()) {
  128. MS_EXCEPTION_IF_NULL(input_node);
  129. if (!input_node->isa<Parameter>()) {
  130. continue;
  131. }
  132. auto parameter = input_node->cast<ParameterPtr>();
  133. MS_EXCEPTION_IF_NULL(parameter);
  134. parameter->DecreaseUsedGraphCount();
  135. // Only the parameter has no graph used, then clear the device tensor.
  136. if (parameter->used_graph_count() != 0) {
  137. continue;
  138. }
  139. auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
  140. DeviceTensorStore::GetInstance().Remove(front_input_node.get());
  141. size_t output_num = common::AnfAlgo::GetOutputTensorNum(input_node);
  142. for (size_t index = 0; index < output_num; ++index) {
  143. if (AnfAlgo::OutputAddrExist(input_node, index)) {
  144. AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
  145. }
  146. }
  147. }
  148. // Clear input value node device tensor and device tensor store.
  149. for (const auto &value_node : graph->graph_value_nodes()) {
  150. auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
  151. DeviceTensorStore::GetInstance().Remove(front_value_node.get());
  152. if (AnfAlgo::OutputAddrExist(value_node, 0)) {
  153. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  154. }
  155. }
  156. // Clear cnode device tensor.
  157. for (const auto &cnode : graph->execution_order()) {
  158. size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
  159. for (size_t index = 0; index < output_num; ++index) {
  160. if (AnfAlgo::OutputAddrExist(cnode, index)) {
  161. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  162. }
  163. }
  164. }
  165. }
  166. #if !defined(_WIN32) && !defined(_WIN64)
  167. void IntHandler(int, siginfo_t *, void *) {
  168. int this_pid = getpid();
  169. MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
  170. (void)kill(this_pid, SIGTERM);
  171. }
  172. #endif
  173. } // namespace
// Release all runtime resources created for the given actor info: terminate and unregister its
// actors, clear the device tensors of the graphs, and remove the related device tensor store
// entries. noexcept because it is used on teardown paths.
void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs,
                           const std::vector<AnfNodePtr> &root_graph_parameters,
                           const ControlNodeParserPtr &parser) noexcept {
  // Terminate the actors of actor info.
  if (actors_.count(actor_info) > 0) {
    auto actor_manager = ActorMgr::GetActorMgrRef();
    if (actor_manager == nullptr) {
      // Can't throw from a noexcept teardown; log and bail out.
      MS_LOG(ERROR) << "Actor manager is not exist.";
      return;
    }
    auto actor_set = actors_[actor_info];
    auto base_actors = CollectActors(actor_set.get());
    for (auto &base_actor : base_actors) {
      MS_EXCEPTION_IF_NULL(base_actor);
      // Unregister from the global actor registry before terminating the actor itself.
      EraseActor(base_actor->GetAID().Name());
      actor_manager->Terminate(base_actor->GetAID());
    }
  }
  // Clear device tensor and device tensor store.
  for (auto &graph : graphs) {
    ClearNodeInfo(graph);
  }
  // Also clear the output addresses of the control-flow front value nodes, if any.
  if (parser != nullptr && parser->IsInited()) {
    const auto &front_value_nodes = parser->front_value_nodes();
    for (const auto &front_value_node : front_value_nodes) {
      const auto &node = front_value_node.first.first;
      size_t index = front_value_node.first.second;
      if (AnfAlgo::OutputAddrExist(node, index)) {
        AnfAlgo::SetOutputAddr(nullptr, index, node.get());
      }
    }
  }
  // Clear the member of DeviceTensorStore.
  for (auto &root_graph_parameter : root_graph_parameters) {
    DeviceTensorStore::GetInstance().Remove(root_graph_parameter.get());
  }
  // Clear global maps of actor info.
  (void)actors_.erase(actor_info);
}
  213. void GraphScheduler::Clear() {
  214. // Terminate all actors.
  215. auto actor_manager = ActorMgr::GetActorMgrRef();
  216. MS_EXCEPTION_IF_NULL(actor_manager);
  217. actor_manager->Finalize();
  218. // Clear the member of DeviceTensorStore.
  219. DeviceTensorStore::GetInstance().Clear();
  220. // Clear global maps.
  221. actors_.clear();
  222. ClearAllActors();
  223. }
  224. void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  225. MS_EXCEPTION_IF_NULL(actor_set);
  226. // Clear the member of DeviceTensorCopyStore.
  227. DeviceTensorCopyStore::GetInstance().Clear();
  228. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  229. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  230. super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
  231. }
  232. control_node_scheduler_.ClearActorData(actor_set->control_actors_.get());
  233. // At the end of the step, the op data sent to the stack actor in each actor should be clear.
  234. auto total_actors = CollectActors(actor_set);
  235. for (auto &actor : total_actors) {
  236. MS_EXCEPTION_IF_NULL(actor);
  237. actor->to_stack_data_.clear();
  238. }
  239. }
  240. using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
  241. const KernelWithIndex &, const KernelGraphPtr &);
  242. static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
// One-shot initialization of the scheduler: register the data-arrow link functions, create the
// actor thread pool, and spawn the global actors. Repeated calls are no-ops.
void GraphScheduler::Initialize() {
  if (init_) {
    return;
  }
  init_ = true;
  // Register the link function for every kernel transform type. LinkDataArrowForBaseActor is the
  // shared default used by several actor kinds.
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForHostDSActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSuperKernelActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
                                      &GraphScheduler::LinkDataArrowForDeviceTensorStore);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
                                      &GraphScheduler::LinkDataArrowForInternalParameter);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSendActor, &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kRecvActor, &GraphScheduler::LinkDataArrowForBaseActor);
  // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
  size_t actor_thread_num = 0;
  size_t actor_and_kernel_thread_num = 0;
  ComputeThreadNums(&actor_thread_num, &actor_and_kernel_thread_num);
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  auto ret = actor_manager->Initialize(true, actor_thread_num, actor_and_kernel_thread_num);
  if (ret != MINDRT_OK) {
    MS_LOG(EXCEPTION) << "Actor manager init failed.";
  }
  common::SetOMPThreadNum();
  // actor_and_kernel_thread_num is the total; the kernel-only count is the difference.
  MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
               << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
#ifdef ENABLE_RPC_ACTOR
  // Create and initialize RpcNodeScheduler.
  rpc_node_scheduler_ = std::make_unique<RpcNodeScheduler>();
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  rpc_node_scheduler_->Initialize();
#endif
  // Spawn the process-wide actors (memory manager, recorder, optional debug actor).
  BuildAndScheduleGlobalActor();
}
  282. void GraphScheduler::BuildAndScheduleGlobalActor() {
  283. auto actor_manager = ActorMgr::GetActorMgrRef();
  284. MS_EXCEPTION_IF_NULL(actor_manager);
  285. // Create and schedule memory manager actor.
  286. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  287. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  288. memory_manager_aid_ = memory_manager_actor->GetAID();
  289. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  290. // Bind single thread to response to memory alloc and free quickly.
  291. (void)actor_manager->Spawn(base_actor, false);
  292. // Create and schedule recorder actor.
  293. auto recorder_actor = std::make_shared<RecorderActor>();
  294. MS_EXCEPTION_IF_NULL(recorder_actor);
  295. recorder_aid_ = &(recorder_actor->GetAID());
  296. auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
  297. (void)actor_manager->Spawn(base_recorder_actor, true);
  298. // Create and schedule debug actor.
  299. // debugger_actor_need is true for CPU when e2e dump is enabled and for Ascend and GPU is true when debugger or dump
  300. // is enabled.
  301. #ifndef ENABLE_SECURITY
  302. bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
  303. #endif
  304. #ifdef ENABLE_DEBUGGER
  305. if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
  306. debugger_actor_need = true;
  307. }
  308. #endif
  309. #ifndef ENABLE_SECURITY
  310. if (debugger_actor_need) {
  311. auto debug_actor = std::make_shared<DebugActor>();
  312. MS_EXCEPTION_IF_NULL(debug_actor);
  313. debug_aid_ = &(debug_actor->GetAID());
  314. auto base_debug_actor = static_cast<ActorReference>(debug_actor);
  315. (void)actor_manager->Spawn(base_debug_actor, true);
  316. }
  317. #endif
  318. }
// Transform the compiled graphs into an actor set: persist device tensors, build the actors,
// cache graph outputs, link the arrows between actors, optimize, and (for the pipeline strategy)
// validate the result. Returns a raw pointer owned by the actors_ map.
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
  // RAII guard: clears the transform-local maps even when an exception is thrown below.
  struct ScopeCleaner {
    GraphScheduler *const scheduler_;
    explicit ScopeCleaner(GraphScheduler *scheduler) : scheduler_(scheduler) {}
    ~ScopeCleaner() {
      // Local maps and vectors clear.
      if (scheduler_ == nullptr) {
        return;
      }
      scheduler_->graph_output_to_actor_.clear();
      scheduler_->copy_actors_.clear();
    }
  };
  // cppcheck-suppress unreadVariable
  ScopeCleaner cleaner(this);
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
  // Graphs and device contexts are paired one-to-one; validate before using them together.
  if (graph_compiler_info.graphs_.size() == 0) {
    MS_LOG(EXCEPTION) << "The number of graphs is zero.";
  }
  if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }
  PersistDeviceTensor(graph_compiler_info);
  const auto &actor_set = Build(graph_compiler_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  // Outputs must be cached before linking so the link phase can resolve graph outputs to actors.
  CacheGraphOutputToActor(graph_compiler_info);
  Link(actor_set.get(), graph_compiler_info);
  Optimize(actor_set.get());
  DumpActor(actor_set.get(), graph_compiler_info);
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
    CheckActorValid(actor_set.get());
  }
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
  return actor_set.get();
}
  354. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  355. MS_EXCEPTION_IF_NULL(actor_set);
  356. auto actors = CollectActors(actor_set);
  357. // Schedule actors.
  358. auto actor_manager = ActorMgr::GetActorMgrRef();
  359. MS_EXCEPTION_IF_NULL(actor_manager);
  360. for (auto actor : actors) {
  361. (void)actor_manager->Spawn(actor);
  362. }
  363. #ifdef ENABLE_RPC_ACTOR
  364. // Build physical connections in 'RpcNodeScheduler::Schedule()' method. This costs some time.
  365. MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  366. rpc_node_scheduler_->Schedule();
  367. #endif
  368. }
// Run one step of the actor set. For the single-op step strategy the kernel actor is invoked
// synchronously; otherwise the data prepare actor is triggered and the result future is awaited.
// On failure the main thread waits briefly so worker threads stop touching the local op_context
// before it is destroyed, then rethrows.
void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
                         const std::vector<std::vector<TensorPtr>> &input_tensors,
                         const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
#if !defined(_WIN32) && !defined(_WIN64)
  // Convert Ctrl-C into SIGTERM for the duration of the run (see IntHandler).
  SignalGuard sg(IntHandler);
#endif
  // Construct OpContext.
  OpContext<DeviceTensor> op_context;
  std::vector<Promise<int>> result(1);
  op_context.sequential_num_ = RandInt::Instance().Get();
  op_context.results_ = &result;
#ifdef ENABLE_RPC_ACTOR
  // Set OpContext to rpc node scheduler.
  auto op_context_setter = std::make_shared<RpcActorOpContextSetter>(rpc_node_scheduler_.get(), &op_context);
  MS_EXCEPTION_IF_NULL(op_context_setter);
#endif
  // Fast path: single-op step mode runs the one kernel actor synchronously and returns.
  if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
    actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context, GraphExecutionStrategy::kStep);
    MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
    actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
    return;
  }
  // Trigger data prepare actor running.
  MS_EXCEPTION_IF_NULL(ActorMgr::GetActorMgrRef());
  auto thread_pool = ActorMgr::GetActorMgrRef()->GetActorThreadPool();
  MS_EXCEPTION_IF_NULL(thread_pool);
  ActorDispatcher::is_multi_thread_execution(actor_set->is_multi_thread_execution_);
  double start_time = GetTime();
  ActorDispatcher::Send(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors,
                        &op_context, GraphExecutionStrategy::kPipeline);
  // Get the run result.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  MsException::Instance().CheckException();
  if (!result_future.IsOK()) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    // When temporary variable 'op_context' has beed set failed status, the main thread need wait other threads until
    // they finish respective task, otherwise segmentation fault will happen when these task access 'op_context',
    // because it has been destroyed.
    std::mutex mutex;
    std::unique_lock<std::mutex> locker(mutex);
    std::condition_variable thread_blocker;
    const int64_t kTimeToWait = 2;
    // No notifier exists: this is a plain bounded sleep to let worker threads drain.
    (void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait));
    // May set exception in the wait time, need throw the exception to avoid affecting the next execution.
    MsException::Instance().CheckException();
    MS_LOG(EXCEPTION) << op_context.error_info_;
  }
  double end_time = GetTime();
  const size_t kSecondsToMilliseconds = 1000;
  // Feed the measured step time into the multi/single-thread execution strategy selection.
  SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
  // Recovery: reinitialize collective communication when requested by the recovery context.
  if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
    MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
    if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
      MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
    }
    MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";
    RecoveryContext::GetInstance()->set_need_reinit_collective(false);
  }
}
  433. void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  434. double execution_time) const {
  435. MS_EXCEPTION_IF_NULL(actor_set);
  436. MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
  437. ++actor_set->execution_count_;
  438. MS_LOG(DEBUG) << "Execution count: " << actor_set->execution_count_ << ", execution time cost: " << execution_time
  439. << " ms in multi thread or not: " << actor_set->is_multi_thread_execution_ << ".";
  440. #if defined(_WIN32) || defined(_WIN64)
  441. return;
  442. #endif
  443. // The step mode uses the default multi thread.
  444. if (strategy == GraphExecutionStrategy::kStep) {
  445. return;
  446. }
  447. // The constraint condition of not supporting the single thread execution.
  448. if ((actor_set->control_actors_ != nullptr) || (actor_set->copy_actors_.size() > 0) ||
  449. (actor_set->super_kernel_actors_.size() > 0) || (actor_set->loop_count_actor_->loop_count() > 1) ||
  450. (actor_set->kernel_actors_.size() > ActorDispatcher::kSingleThreadExecutionActorMaxNum)) {
  451. return;
  452. }
  453. if ((actor_set->is_multi_thread_execution_) &&
  454. (actor_set->execution_count_ >= ActorDispatcher::kMultiThreadExecutionCountBegin) &&
  455. (actor_set->execution_count_ <= ActorDispatcher::kMultiThreadExecutionCountEnd)) {
  456. actor_set->multi_thread_execution_time_ += execution_time;
  457. if (actor_set->execution_count_ == ActorDispatcher::kMultiThreadExecutionCountEnd) {
  458. actor_set->multi_thread_execution_time_ /=
  459. ((ActorDispatcher::kMultiThreadExecutionCountEnd - ActorDispatcher::kMultiThreadExecutionCountBegin) + 1);
  460. actor_set->is_multi_thread_execution_ = false;
  461. }
  462. return;
  463. }
  464. if ((!actor_set->is_multi_thread_execution_) &&
  465. (actor_set->execution_count_ >= ActorDispatcher::kSingleThreadExecutionCountBegin) &&
  466. (actor_set->execution_count_ <= ActorDispatcher::kSingleThreadExecutionCountEnd)) {
  467. actor_set->single_thread_execution_time_ += execution_time;
  468. if (actor_set->execution_count_ == ActorDispatcher::kSingleThreadExecutionCountEnd) {
  469. actor_set->single_thread_execution_time_ /=
  470. (ActorDispatcher::kSingleThreadExecutionCountEnd - ActorDispatcher::kSingleThreadExecutionCountBegin + 1);
  471. actor_set->is_multi_thread_execution_ =
  472. (actor_set->multi_thread_execution_time_ <= actor_set->single_thread_execution_time_) ? true : false;
  473. MS_LOG(INFO) << "Multi thread execution time cost: " << actor_set->multi_thread_execution_time_
  474. << " ms, single thread execution time cost: " << actor_set->single_thread_execution_time_
  475. << " ms, decide to use multi thread execution or not: " << actor_set->is_multi_thread_execution_
  476. << ".";
  477. }
  478. return;
  479. }
  480. }
  481. ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
  482. auto iter = actors_.find(actor_info);
  483. if (iter != actors_.end()) {
  484. return iter->second.get();
  485. } else {
  486. MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
  487. return nullptr;
  488. }
  489. }
// Build the complete actor set for one compiled graph collection and register
// it into actors_ under its name. The build order matters: the data prepare
// actor is constructed last because it consumes the data source actors built
// above, and the shared host tensor queue ties the two together.
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
  auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
  MS_EXCEPTION_IF_NULL(actor_set);
  (void)actors_.emplace(actor_set->name_, actor_set);
  // Shared by the host queue data source actor and the data prepare actor.
  auto host_queue = std::make_shared<HostTensorQueue>();
  actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
  actor_set->custom_actors_ = BuildCustomActor(graph_compiler_info);
  actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  // Needs the data source actors created above to find the host queue DS actor.
  actor_set->data_prepare_actor_ =
    BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
  actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
#ifdef ENABLE_RPC_ACTOR
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  actor_set->rpc_actors_ = rpc_node_scheduler_->Build(graph_compiler_info);
#endif
  return actor_set;
}
// Cache the mapping from each graph's front output node to the (actor, backend
// output) pair that produces it, filling graph_output_to_actor_. This cache is
// later used to resolve internal parameters to their producing actors.
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
  // Step mode fetches results synchronously, so the cache is not needed.
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
    return;
  }
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output_with_index : outputs) {
      auto output_kernel = output_with_index.first;
      MS_EXCEPTION_IF_NULL(output_kernel);
      // Backend outputs without a corresponding front node cannot be cached.
      auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
      if (origin_output_with_index.first == nullptr) {
        MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                        << " with index: " << output_with_index.second << " has no front node.";
        continue;
      }
      auto kernel_type = FetchKernelTransformType(output_kernel, graph, graph_compiler_info.origin_parameters_order_);
      auto output_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_kernel, graph);
      // A missing actor is tolerated: the pair is still cached with a nullptr
      // actor so lookups can distinguish "known output, no actor".
      if (output_actor == nullptr) {
        MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                     << " with index:" << output_with_index.second
                     << " is not actor, and the kernel type is:" << kernel_type;
      }
      auto output_actor_name = (output_actor != nullptr) ? output_actor->GetAID().Name() : "";
      (void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index));
      MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                   << " with index:" << output_with_index.second << " to actor:" << output_actor_name
                   << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
                   << " with index:" << origin_output_with_index.second;
    }
  }
}
// Link all arrows between the actors of the actor set. Sink graphs are linked
// as one super kernel actor; non-sink graphs are linked per kernel and their
// communication nodes are collected per group name so the global control
// arrows can serialize them afterwards.
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<AbstractActor *> auto_monad_actors;
  GroupNameToCommuNodes group_name_to_communication_nodes;
  std::string default_group_name = "";
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->execution_order().empty()) {
      MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips linking.";
      continue;
    }
    if (graph->is_executing_sink()) {
      LinkDataArrowInSinkMode(graph, graph_compiler_info, &auto_monad_actors);
    } else {
      // In the control flow, the communication nodes need to be guaranteed to be executed in order. The order
      // within the kernel graph group needs to add control arrows between the communication nodes, and the order
      // between groups is guaranteed by the control flow framework. Therefore, communication nodes need to be
      // grouped by group name. And this is not required in non-control flow, the default unified group name is used.
      std::vector<CNodePtr> communication_nodes;
      const auto &group_name = (parser->IsInited() ? parser->FetchGroupNameByKernelGraph(graph) : default_group_name);
      LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes);
      group_name_to_communication_nodes[group_name].insert(group_name_to_communication_nodes[group_name].end(),
                                                           communication_nodes.begin(), communication_nodes.end());
    }
  }
  LinkGlobalControlArrow(actor_set, group_name_to_communication_nodes, auto_monad_actors, graph_compiler_info);
  LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
  // The copy actors are built in the link, so need push into the actor set after link.
  actor_set->copy_actors_ = copy_actors_;
  // Link the arrow in the control flow scene.
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline &&
      graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
    control_node_scheduler_.Link(actor_set, graph_compiler_info);
  }
#ifdef ENABLE_RPC_ACTOR
  // Link inter-process arrows for rpc actors.
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  rpc_node_scheduler_->Link(actor_set);
#endif
}
// Post-link optimization pass over the actor set; currently it only delegates
// to the control node scheduler to optimize the control actors.
void GraphScheduler::Optimize(ActorSet *const actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
  control_node_scheduler_.Optimize(actor_set->control_actors_.get());
}
  588. std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  589. const HostTensorQueuePtr &host_queue) {
  590. std::vector<DataSourceActorPtr> data_source_actors;
  591. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  592. size_t data_node_position = 0;
  593. mindspore::HashMap<AnfNodePtr, size_t> front_node_position_temp_map;
  594. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  595. const auto &graph = graph_compiler_info.graphs_[i];
  596. const auto &device_context = graph_compiler_info.device_contexts_[i];
  597. MS_EXCEPTION_IF_NULL(graph);
  598. // Build host queue data source actor.
  599. const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
  600. const auto &root_parameters = graph_compiler_info.origin_parameters_order_;
  601. for (size_t j = 0; j < input_nodes.size(); j++) {
  602. const auto &input_node = input_nodes[j];
  603. MS_EXCEPTION_IF_NULL(input_node);
  604. if (IsHostQueueDSActor(input_node, graph, root_parameters, graph_compiler_info.strategy_)) {
  605. // In control flow, parameters from subgraph need not init in data source actor.
  606. if (graph_compiler_info.control_node_parser_->IsInited()) {
  607. auto node_with_index = graph->GetElementInTupleBackendFrontIndexMap(input_node);
  608. if (node_with_index.first != nullptr && node_with_index.first->isa<Parameter>() &&
  609. find(root_parameters.begin(), root_parameters.end(), node_with_index.first) == root_parameters.end())
  610. continue;
  611. }
  612. if (host_queue_ds_actor == nullptr) {
  613. auto actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
  614. MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
  615. host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
  616. nullptr, host_queue);
  617. InsertActor(host_queue_ds_actor.get());
  618. (void)data_source_actors.emplace_back(host_queue_ds_actor);
  619. }
  620. const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
  621. // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
  622. // is saved in the host queue data source actor.
  623. if (front_node_position_temp_map.count(front_node) > 0) {
  624. (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
  625. front_node_position_temp_map[front_node]);
  626. continue;
  627. }
  628. (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
  629. (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
  630. (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
  631. // In control flow, need to rely on the front node to find the location of the corresponding real parameter.
  632. (void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
  633. (void)front_node_position_temp_map.emplace(front_node, data_node_position);
  634. data_node_position++;
  635. }
  636. }
  637. // The graph sink mode has no device queue data source actor.
  638. if (!graph->is_executing_sink()) {
  639. // Build device queue data source actor.
  640. const auto &execution_order = graph->execution_order();
  641. const auto &iter =
  642. std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
  643. return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
  644. });
  645. if (iter != execution_order.end()) {
  646. auto actor_name =
  647. graph_compiler_info.name_ + kDeviceDSActorNameSuffix + "_" + std::to_string(graph->graph_id());
  648. MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
  649. auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
  650. actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
  651. MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
  652. InsertActor(device_queue_ds_actor.get());
  653. (void)data_source_actors.emplace_back(device_queue_ds_actor);
  654. device_queue_ds_actor->data_kernel_ = *iter;
  655. device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
  656. }
  657. }
  658. }
  659. control_node_scheduler_.BuildDataSourceActorForControlNode(graph_compiler_info, host_queue, host_queue_ds_actor,
  660. memory_manager_aid_, &data_source_actors);
  661. return data_source_actors;
  662. }
  663. std::vector<CustomActorPtr> GraphScheduler::BuildCustomActor(const GraphCompilerInfo &graph_compiler_info) {
  664. std::vector<CustomActorPtr> custom_actors;
  665. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  666. const auto &device_context = graph_compiler_info.device_contexts_[i];
  667. const auto &graph = graph_compiler_info.graphs_[i];
  668. MS_EXCEPTION_IF_NULL(graph);
  669. if (graph->is_executing_sink()) {
  670. continue;
  671. }
  672. auto all_nodes = TopoSort(graph->get_return());
  673. for (const auto &node : all_nodes) {
  674. if (!AnfUtils::IsCustomActorNode(node)) {
  675. continue;
  676. }
  677. auto actor_name = AnfUtils::GetCustomActorName(node);
  678. auto custom_actor = std::make_shared<CustomActor>(actor_name, node, device_context, recorder_aid_);
  679. MS_EXCEPTION_IF_NULL(custom_actor);
  680. InsertActor(custom_actor.get());
  681. custom_actors.emplace_back(custom_actor);
  682. }
  683. }
  684. return custom_actors;
  685. }
  686. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  687. std::vector<KernelActorPtr> kernel_actors;
  688. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  689. const auto &graph = graph_compiler_info.graphs_[i];
  690. const auto &device_context = graph_compiler_info.device_contexts_[i];
  691. MS_EXCEPTION_IF_NULL(graph);
  692. if (graph->is_executing_sink()) {
  693. continue;
  694. }
  695. auto execution_order = graph->execution_order();
  696. // Single op graph in step mode, kernel actor executes synchronously.
  697. bool is_single_op_graph = execution_order.size() == 1;
  698. GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
  699. if (strategy == GraphExecutionStrategy::kStep) {
  700. strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
  701. }
  702. for (auto &kernel : execution_order) {
  703. MS_EXCEPTION_IF_NULL(kernel);
  704. if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
  705. auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
  706. auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
  707. KernelActorPtr kernel_actor = nullptr;
  708. if (IsRpcActor(kernel)) {
  709. kernel_actor = GenerateRpcActor(kernel, device_context, strategy, ref_input_indexes, ref_output_indexes);
  710. } else {
  711. kernel_actor =
  712. std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
  713. debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
  714. }
  715. MS_EXCEPTION_IF_NULL(kernel_actor);
  716. InsertActor(kernel_actor.get());
  717. (void)kernel_actors.emplace_back(kernel_actor);
  718. }
  719. }
  720. }
  721. return kernel_actors;
  722. }
  723. std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  724. std::vector<SuperKernelActorPtr> super_kernel_actors;
  725. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  726. const auto &graph = graph_compiler_info.graphs_[i];
  727. const auto &device_context = graph_compiler_info.device_contexts_[i];
  728. MS_EXCEPTION_IF_NULL(graph);
  729. if (!graph->is_executing_sink()) {
  730. continue;
  731. }
  732. if (graph->execution_order().empty()) {
  733. MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips building.";
  734. continue;
  735. }
  736. auto actor_name = graph->ToString() + kSuperKernelActorNameSuffix;
  737. auto super_kernel_actor =
  738. std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, debug_aid_, nullptr);
  739. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  740. InsertActor(super_kernel_actor.get());
  741. (void)super_kernel_actors.emplace_back(super_kernel_actor);
  742. }
  743. return super_kernel_actors;
  744. }
  745. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
  746. auto actor_set = Fetch(graph_compiler_info.name_);
  747. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  748. return nullptr;
  749. }
  750. auto loop_count = ConfigManager::GetInstance().iter_num();
  751. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  752. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  753. loop_count = 1;
  754. }
  755. auto actor_name = graph_compiler_info.name_ + kLoopCountActorNameSuffix;
  756. auto loop_count_actor =
  757. std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_,
  758. graph_compiler_info.strategy_, graph_compiler_info.device_contexts_);
  759. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  760. MS_EXCEPTION_IF_NULL(loop_count_actor);
  761. InsertActor(loop_count_actor.get());
  762. return loop_count_actor;
  763. }
  764. OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
  765. auto actor_set = Fetch(graph_compiler_info.name_);
  766. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  767. return nullptr;
  768. }
  769. auto loop_count = ConfigManager::GetInstance().iter_num();
  770. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  771. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  772. loop_count = 1;
  773. }
  774. auto actor_name = graph_compiler_info.name_ + kOutputActorNameSuffix;
  775. auto output_actor = std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_);
  776. MS_LOG(INFO) << "Create output actor: " << actor_name;
  777. MS_EXCEPTION_IF_NULL(output_actor);
  778. InsertActor(output_actor.get());
  779. return output_actor;
  780. }
  781. DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  782. const std::vector<DataSourceActorPtr> &data_source_actors,
  783. const HostTensorQueuePtr &host_queue) {
  784. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  785. auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
  786. return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
  787. });
  788. if (iter != data_source_actors.end()) {
  789. host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
  790. }
  791. auto actor_name = graph_compiler_info.name_ + kDataPrepareActorNameSuffix;
  792. auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
  793. &graph_compiler_info, host_queue_ds_actor, host_queue);
  794. MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
  795. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  796. // Cache the nodes which need continuous memory.
  797. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  798. for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
  799. const auto &graph = graph_compiler_info.graphs_[index];
  800. MS_EXCEPTION_IF_NULL(graph);
  801. if (graph->is_executing_sink()) {
  802. continue;
  803. }
  804. auto &execution_order = graph->execution_order();
  805. for (auto &kernel : execution_order) {
  806. if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
  807. continue;
  808. }
  809. auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
  810. auto value = std::make_pair(false, false);
  811. if (common::AnfAlgo::GetInputTensorNum(kernel) > 1) {
  812. value.first = true;
  813. }
  814. if (common::AnfAlgo::GetOutputTensorNum(kernel) > 1) {
  815. value.second = true;
  816. }
  817. if ((value.first == true) || (value.second == true)) {
  818. data_prepare_actor->continuous_memory_nodes_[key] = value;
  819. }
  820. }
  821. }
  822. }
  823. InsertActor(data_prepare_actor.get());
  824. return data_prepare_actor;
  825. }
  826. std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
  827. GraphExecutionStrategy strategy) {
  828. MS_EXCEPTION_IF_NULL(actor_set);
  829. std::vector<AbstractActorPtr> no_input_kernel_actors;
  830. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  831. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  832. if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) {
  833. (void)no_input_kernel_actors.emplace_back(super_kernel_actor);
  834. }
  835. }
  836. for (auto &kernel_actor : actor_set->kernel_actors_) {
  837. MS_EXCEPTION_IF_NULL(kernel_actor);
  838. // Framework will trigger kernel actor running in the step execution strategy.
  839. if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
  840. kernel_actor->input_controls_num_++;
  841. continue;
  842. }
  843. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  844. (void)no_input_kernel_actors.emplace_back(kernel_actor);
  845. }
  846. }
  847. return no_input_kernel_actors;
  848. }
// Create a Send or Recv actor for an rpc kernel and register it with the rpc
// node scheduler. Throws when the kernel is neither an rpc send nor recv op.
// Returns nullptr when rpc support is compiled out (ENABLE_RPC_ACTOR undefined).
KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
                                                GraphExecutionStrategy strategy,
                                                const std::set<size_t> &ref_input_indexes,
                                                const std::set<size_t> &ref_output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(device_context);
#ifdef ENABLE_RPC_ACTOR
  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
  if (common::AnfAlgo::GetCNodeName(kernel) == kRpcSendOpName) {
    auto send_actor =
      std::make_shared<SendActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
                                  debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(send_actor);
    rpc_node_scheduler_->InsertSendActor(send_actor);
    return send_actor;
  } else if (common::AnfAlgo::GetCNodeName(kernel) == kRpcRecvOpName) {
    auto recv_actor =
      std::make_shared<RecvActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
                                  debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(recv_actor);
    rpc_node_scheduler_->InsertRecvActor(recv_actor);
    return recv_actor;
  } else {
    MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope() << " is not an rpc kernel.";
  }
#endif
  // Reached only when ENABLE_RPC_ACTOR is not defined.
  return nullptr;
}
// Link the data arrows from the graph inputs to the graph's super kernel actor
// (sink mode), and record the persistent device tensors reachable through
// auto-monad inputs so their stores can be handled after linking.
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
                                             std::vector<AbstractActor *> *const auto_monad_actors) {
  MS_EXCEPTION_IF_NULL(graph);
  // The data arrow linking is taken over by the control flow.
  if (graph_compiler_info.control_node_parser_ != nullptr &&
      graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, nullptr)) {
    return;
  }
  // The super kernel actor was registered under this name by BuildSuperKernelActor.
  auto to_actor_name = graph->ToString() + kSuperKernelActorNameSuffix;
  auto to_actor = FetchActor(to_actor_name);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto &input_nodes = graph->input_nodes();
  for (size_t node_index = 0; node_index < input_nodes.size(); ++node_index) {
    auto &input_node = input_nodes[node_index];
    MS_EXCEPTION_IF_NULL(input_node);
    if (HasAbstractMonad(input_node)) {
      MS_LOG(INFO) << "The graph:" << graph->graph_id()
                   << " has abstract monad input node:" << input_node->DebugString() << ", input index:" << node_index;
      LinkControlArrowByAutoMonad(to_actor, input_node, graph);
      continue;  // No data arrow for monad input.
    }
    UpdateRefCount(input_node, 0, true);
    KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0);
    KernelWithIndex to_kernel_with_input_idx = std::make_pair(input_node, node_index);
    // The gather of linking data arrows of kernel by the different from kernel type.
    LinkDataArrow(to_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
  }
  std::vector<CNodePtr> auto_monad_kernels;
  // Foreach the execution order to get the auto monad kernels.
  // NOTE(review): a kernel with several monad inputs is emplaced once per such
  // input; the duplicates only cause repeated set inserts below — confirm intended.
  auto &execution_order = graph->execution_order();
  (void)std::for_each(execution_order.begin(), execution_order.end(), [&](const CNodePtr &kernel) {
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
      if (HasAbstractMonad(input_node)) {
        (void)auto_monad_kernels.emplace_back(kernel);
        continue;
      }
    }
  });
  // Foreach auto monad kernels to get the auto monad device tensor stores.
  (void)std::for_each(auto_monad_kernels.begin(), auto_monad_kernels.end(), [&](const CNodePtr &kernel) {
    for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(kernel); ++i) {
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, i, false);
      auto front_node = FetchFrontNodeByBackendNode(from_kernel_with_output_idx.first, graph);
      if (IsPersistentDeviceTensor(front_node)) {
        (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
      }
    }
  });
  if (to_actor->auto_monad_device_tensor_stores_.size() > 0) {
    (void)auto_monad_actors->emplace_back(to_actor);
  }
}
// Link data arrows per kernel of a non-sink graph. Also collects the graph's
// communication kernels (for inter-group ordering by the caller) and the
// actors touched by auto-monad inputs.
void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
                                                const GraphCompilerInfo &graph_compiler_info,
                                                std::vector<AbstractActor *> *const auto_monad_actors,
                                                std::vector<CNodePtr> *const communication_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(auto_monad_actors);
  MS_EXCEPTION_IF_NULL(communication_nodes);
  // Primitives whose presence on an input triggers auto-monad control linking.
  const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
  auto &execution_order = graph->execution_order();
  // Foreach the execution order to link the actors.
  for (const auto &kernel : execution_order) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsCommunicationOp(kernel)) {
      (void)communication_nodes->emplace_back(kernel);
    }
    if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
      continue;
    }
    // Kernel actors are registered under the kernel's full scope name.
    const auto &kernel_actor = FetchActor(kernel->fullname_with_scope());
    MS_EXCEPTION_IF_NULL(kernel_actor);
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
      // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
      if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
        LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_);
      }
      if (HasAbstractMonad(input_node)) {
        (void)auto_monad_actors->emplace_back(kernel_actor);
        continue;  // No data arrow for monad input.
      }
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
      KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
      // The data arrow linking is taken over by the control flow.
      if (graph_compiler_info.control_node_parser_ != nullptr &&
          graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, from_kernel_with_output_idx.first)) {
        continue;
      }
      // The gather of linking data arrows of kernel by the different from kernel type.
      LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
    }
  }
  // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  LinkControlArrowBySendRecvNodes(graph);
}
  975. void GraphScheduler::LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  976. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  977. const KernelWithIndex &to_kernel_with_input_idx) {
  978. MS_EXCEPTION_IF_NULL(to_actor);
  979. MS_EXCEPTION_IF_NULL(graph);
  980. auto from_kernel = from_kernel_with_output_idx.first;
  981. MS_EXCEPTION_IF_NULL(from_kernel);
  982. auto kernel_type = FetchKernelTransformType(from_kernel, graph, graph_compiler_info.origin_parameters_order_,
  983. graph_compiler_info.strategy_);
  984. auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, from_kernel, graph);
  985. if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
  986. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  987. MS_LOG(WARNING) << "Invalid from node:" << from_kernel->fullname_with_scope() << ", type:" << kernel_type;
  988. }
  989. return;
  990. }
  991. (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
  992. to_kernel_with_input_idx, graph);
  993. }
  994. void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor,
  995. const KernelWithIndex &from_kernel_with_output_idx,
  996. const KernelWithIndex &to_kernel_with_input_idx,
  997. const KernelGraphPtr &graph) {
  998. MS_EXCEPTION_IF_NULL(to_actor);
  999. MS_EXCEPTION_IF_NULL(graph);
  1000. auto from_kernel = from_kernel_with_output_idx.first;
  1001. MS_EXCEPTION_IF_NULL(from_kernel);
  1002. auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
  1003. (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
  1004. }
// Resolve an internal parameter (a graph input that is actually the output of
// another graph) back to its real producing actor via the front node, then
// re-dispatch the linking with the resolved actor/node pair.
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor,
                                                       const KernelWithIndex &from_kernel_with_output_idx,
                                                       const KernelWithIndex &to_kernel_with_input_idx,
                                                       const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  auto internal_parameter = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(internal_parameter);
  // Parameter ---> front node.
  auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
  auto front_output_node = front_output_with_index.first;
  MS_EXCEPTION_IF_NULL(front_output_node);
  // Switch actors are linked by the control flow scheduler, not here.
  if (IsSwitchActor(front_output_node)) {
    return;
  }
  auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  AbstractActor *real_from_actor = nullptr;
  KernelTransformType kernel_type;
  if (IsPersistentDeviceTensor(front_output_node)) {
    // Persistent tensors have no producing actor; link via the device tensor store.
    kernel_type = KernelTransformType::kDeviceTensorStore;
  } else {
    // front node ---> actor, using the cache filled by CacheGraphOutputToActor.
    if (graph_output_to_actor_.count(front_output_with_index) == 0) {
      MS_LOG(EXCEPTION) << "Can't find actor by front node:" << common::AnfAlgo::GetNodeDebugString(front_output_node)
                        << ", internal parameter:" << common::AnfAlgo::GetNodeDebugString(internal_parameter);
    }
    auto actor_pair = graph_output_to_actor_[front_output_with_index];
    MS_EXCEPTION_IF_NULL(actor_pair.first);
    MS_EXCEPTION_IF_NULL(actor_pair.second.first);
    MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
                 << ", corresponding front node:" << front_output_node->fullname_with_scope()
                 << " with index:" << front_output_with_index.second
                 << ", from actor:" << actor_pair.first->GetAID().Name()
                 << " node:" << actor_pair.second.first->fullname_with_scope()
                 << " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name()
                 << " with index:" << to_kernel_with_input_idx.second;
    real_from_actor = actor_pair.first;
    real_from_kernel_with_output_idx = actor_pair.second;
    kernel_type = actor_pair.first->type_;
  }
  if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
    MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
  }
  // Re-dispatch with the real producer resolved.
  (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
                                              to_kernel_with_input_idx, graph);
}
  1051. void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1052. const KernelWithIndex &from_kernel_with_output_idx,
  1053. const KernelWithIndex &to_kernel_with_input_idx,
  1054. const KernelGraphPtr &) {
  1055. MS_EXCEPTION_IF_NULL(from_actor);
  1056. MS_EXCEPTION_IF_NULL(to_actor);
  1057. auto from_kernel = from_kernel_with_output_idx.first;
  1058. MS_EXCEPTION_IF_NULL(from_kernel);
  1059. auto from_output_index = from_kernel_with_output_idx.second;
  1060. auto to_input_index = to_kernel_with_input_idx.second;
  1061. // Get the position of from kernel in the data source actor.
  1062. auto position = from_actor->FetchNodePosition(from_kernel);
  1063. if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
  1064. MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  1065. }
  1066. if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
  1067. LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
  1068. } else {
  1069. AddDataArrow(from_actor, to_actor, from_kernel, from_output_index, to_input_index);
  1070. }
  1071. }
  1072. void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1073. const KernelWithIndex &from_kernel_with_output_idx,
  1074. const KernelWithIndex &to_kernel_with_input_idx,
  1075. const KernelGraphPtr &graph) {
  1076. auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  1077. MS_EXCEPTION_IF_NULL(host_ds_actor);
  1078. MS_EXCEPTION_IF_NULL(from_kernel_with_output_idx.first);
  1079. KernelWithIndex real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1080. // Get the position and real kernel by from kernel in the data source actor.
  1081. auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
  1082. real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
  1083. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx, graph);
  1084. }
  1085. void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1086. const KernelWithIndex &from_kernel_with_output_idx,
  1087. const KernelWithIndex &to_kernel_with_input_idx,
  1088. const KernelGraphPtr &graph) {
  1089. auto real_from_actor = from_actor;
  1090. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1091. auto from_kernel = from_kernel_with_output_idx.first;
  1092. // Update the from kernel info by the real node info.
  1093. MS_EXCEPTION_IF_NULL(from_kernel);
  1094. if (IsSkippedKernelActor(from_kernel)) {
  1095. real_from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(from_kernel, 0, false);
  1096. MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
  1097. LinkControlArrowBySkippedNode(to_actor, from_kernel);
  1098. MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
  1099. MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
  1100. << to_kernel_with_input_idx.first->fullname_with_scope()
  1101. << ", aggregate input index: " << to_kernel_with_input_idx.second
  1102. << ", skip node: " << from_kernel->fullname_with_scope()
  1103. << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
  1104. real_from_actor = FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope());
  1105. MS_EXCEPTION_IF_NULL(real_from_actor);
  1106. }
  1107. LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx,
  1108. graph);
  1109. }
// Routes a cross-device data arrow through a copy actor. A copy actor is keyed by
// (from actor, from kernel, output index); the first consumer creates and fully initializes
// it, later consumers reuse it and only add the outgoing arrow.
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
                                               const KernelWithIndex &from_kernel_with_output_idx,
                                               const KernelWithIndex &to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  // The name uniquely identifies one producer output, so the same copy actor is shared by
  // all consumers of that output.
  std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
                     "_output_index:" + std::to_string(from_kernel_with_output_idx.second);
  CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  // Link between from actor and copy actor.
  if (copy_actor == nullptr) {
    // Create the copy actor.
    auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
    (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
    copy_actor = copy_actor_shared_ptr.get();
    MS_EXCEPTION_IF_NULL(copy_actor);
    InsertActor(copy_actor);
    // Set the member device_contexts_ of the copy actor: index 0 is the source context,
    // index 1 the destination context.
    auto position = from_actor->FetchNodePosition(from_kernel);
    if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
      MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
    }
    auto from_device_context = from_actor->device_contexts_[position];
    auto to_device_context = to_actor->device_contexts_[0];
    MS_EXCEPTION_IF_NULL(from_device_context);
    MS_EXCEPTION_IF_NULL(to_device_context);
    (void)copy_actor->device_contexts_.emplace_back(from_device_context);
    (void)copy_actor->device_contexts_.emplace_back(to_device_context);
    // Set the member output_ of the copy actor. For a super kernel actor the destination
    // address is the node's own output 0; otherwise it is the consumer input's address.
    if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
      copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false);
    } else {
      copy_actor->output_ =
        AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false);
    }
    MS_EXCEPTION_IF_NULL(copy_actor->output_);
    if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) {
      MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType()
                        << ", to device context type:" << to_device_context->GetDeviceAddressType();
    }
    // Link between from actor and copy actor.
    AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0);
  }
  // If the copy actor already exists, only need link between copy actor and to actor.
  AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second);
  // Bump the ref count of the copy destination for each consumer; for a super kernel actor
  // the count is marked as used by the whole graph (second argument true).
  if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
    UpdateRefCount(copy_actor->output_.get(), true);
  } else {
    UpdateRefCount(copy_actor->output_.get(), false);
  }
}
// Links control arrows expressing auto-monad (side effect) ordering: the actors producing the
// real dependencies hidden behind Depend/UpdateState/Load/MakeTuple nodes must run before
// to_actor. Recurses through nested monad/make-tuple nodes.
void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
                                                 const KernelGraphPtr &graph, const ControlNodeParserPtr &parser) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_node);
  MS_EXCEPTION_IF_NULL(graph);
  // Find the real input node, include the monad node and make tuple node.
  const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
                                                  prim::kPrimMakeTuple};
  const auto &input_kernel_with_output_idx =
    common::AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
  MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  auto input_anfnode = input_kernel_with_output_idx.first;
  CNodePtr input_cnode = nullptr;
  if (input_anfnode->isa<CNode>()) {
    input_cnode = input_anfnode->cast<CNodePtr>();
  }
  // Make tuple node needs to be expanded: link each element separately.
  if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
      LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser);
    }
    return;
  }
  const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  // Get the real depend input by monad node which needs to link the control arrow.
  std::vector<AnfNodePtr> real_depend_inputs;
  if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
      common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
    // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
    // node in this scene.
    if (IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
      real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
    }
  } else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
    // All real inputs of UpdateState are dependencies.
    MS_EXCEPTION_IF_NULL(input_cnode);
    for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
      real_depend_inputs.push_back(input_cnode->input(i));
    }
  } else {
    real_depend_inputs.push_back(input_anfnode);
  }
  for (const auto &real_depend_input : real_depend_inputs) {
    auto real_depend_input_with_idx =
      common::AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
    MS_EXCEPTION_IF_NULL(real_depend_input_with_idx.first);
    auto real_depend_kernel = real_depend_input_with_idx.first;
    // Update the real depend kernel in the subgraphs connecting scene: an internal parameter is
    // replaced by the backend node of the corresponding front graph output.
    if (IsInternalParameter(real_depend_kernel, graph)) {
      auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
      MS_EXCEPTION_IF_NULL(front_output_with_index.first);
      if (graph_output_to_actor_.count(front_output_with_index) == 0) {
        // A call node output is produced under control flow and is linked elsewhere.
        if (common::AnfAlgo::IsCallNode(front_output_with_index.first)) {
          continue;
        }
        MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
      }
      if (parser != nullptr && parser->IsInited() &&
          (!parser->IsSameKernelGraphGroup(front_output_with_index.first, graph))) {
        MS_LOG(DEBUG) << "Skip in control flow from node:" << front_output_with_index.first->DebugString()
                      << " is not in the graph:" << graph->ToString();
        continue;
      }
      real_depend_kernel = graph_output_to_actor_[front_output_with_index].second.first;
      MS_EXCEPTION_IF_NULL(real_depend_kernel);
      MS_LOG(INFO) << "The graph " << graph->graph_id() << " link control arrow by auto monad from internal parameter: "
                   << real_depend_input_with_idx.first->DebugString()
                   << ", front output node: " << front_output_with_index.first->fullname_with_scope()
                   << ", backend output node: " << real_depend_kernel->fullname_with_scope();
      auto from_actor = graph_output_to_actor_[front_output_with_index].first;
      if (from_actor != nullptr) {
        MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
                     << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
        AddControlArrow(from_actor, to_actor);
        continue;
      }
    }
    // The monad node and make tuple node need recursion.
    if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
      LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser);
      continue;
    }
    auto type = FetchKernelTransformType(real_depend_kernel, nullptr);
    auto from_actor = FetchActor(type, "", real_depend_kernel);
    if (from_actor == nullptr) {
      // Not every depend source has its own actor (e.g. parameters); nothing to link then.
      MS_LOG(DEBUG) << "Link control arrow by auto monad from depend node:" << real_depend_kernel->fullname_with_scope()
                    << " is not actor for the graph: " << graph->graph_id();
      continue;
    }
    MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
                 << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
    AddControlArrow(from_actor, to_actor);
  }
}
  1259. void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) {
  1260. MS_EXCEPTION_IF_NULL(to_actor);
  1261. MS_EXCEPTION_IF_NULL(skipped_node);
  1262. // Link the control arrow from all the inputs of skipped node to the user of skipped node.
  1263. auto input_num = common::AnfAlgo::GetInputTensorNum(skipped_node);
  1264. for (size_t i = 0; i < input_num; ++i) {
  1265. auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
  1266. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1267. auto from_actor = FetchActor(kernel_with_index.first->fullname_with_scope());
  1268. MS_EXCEPTION_IF_NULL(from_actor);
  1269. MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
  1270. << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
  1271. AddControlArrow(from_actor, to_actor);
  1272. }
  1273. }
// Links control arrows for the send/recv pairs inserted around AllReduce nodes, so the
// communication runs on its own stream but still synchronizes with the compute stream:
// send/recv before an allreduce gate its inputs, send/recv after an allreduce gate its outputs.
void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Pairs whose send/recv run before the allreduce.
  for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
    auto to_allreduce_node = from_iter.first;
    auto from_send_node = from_iter.second.first;
    auto from_recv_node = from_iter.second.second;
    MS_EXCEPTION_IF_NULL(to_allreduce_node);
    MS_EXCEPTION_IF_NULL(from_send_node);
    MS_EXCEPTION_IF_NULL(from_recv_node);
    MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
    auto to_allreduce_actor = FetchActor(to_allreduce_node->fullname_with_scope());
    auto from_send_actor = FetchActor(from_send_node->fullname_with_scope());
    auto from_recv_actor = FetchActor(from_recv_node->fullname_with_scope());
    MS_EXCEPTION_IF_NULL(to_allreduce_actor);
    MS_EXCEPTION_IF_NULL(from_send_actor);
    MS_EXCEPTION_IF_NULL(from_recv_actor);
    // inputs of to_allreduce_actor --> from_send_actor
    for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
      auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
      if (input_actor != nullptr) {
        AddControlArrow(input_actor, from_send_actor);
      }
    }
    // from_send_actor --> from_recv_actor
    AddControlArrow(from_send_actor, from_recv_actor);
    // from_recv_actor --> to_allreduce_actor
    AddControlArrow(from_recv_actor, to_allreduce_actor);
  }
  // Pairs whose send/recv run after the allreduce.
  for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
    auto from_allreduce_node = to_iter.first;
    auto to_send_node = to_iter.second.first;
    auto to_recv_node = to_iter.second.second;
    MS_EXCEPTION_IF_NULL(from_allreduce_node);
    MS_EXCEPTION_IF_NULL(to_send_node);
    MS_EXCEPTION_IF_NULL(to_recv_node);
    MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
    auto from_allreduce_actor = FetchActor(from_allreduce_node->fullname_with_scope());
    auto to_send_actor = FetchActor(to_send_node->fullname_with_scope());
    auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(from_allreduce_actor);
    MS_EXCEPTION_IF_NULL(to_send_actor);
    MS_EXCEPTION_IF_NULL(to_recv_actor);
    // from_allreduce_actor --> to_send_actor
    AddControlArrow(from_allreduce_actor, to_send_actor);
    // to_send_actor --> to_recv_actor
    AddControlArrow(to_send_actor, to_recv_actor);
    // to_recv_actor --> outputs of from_allreduce_actor
    for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
      auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
      if (output_actor != nullptr) {
        AddControlArrow(to_recv_actor, output_actor);
      }
    }
    // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
    // reused only when the recv node runs finished, which is expressed by the reference count increased.
    for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
      auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      UpdateRefCount(device_tensor.get());
      (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
    }
  }
}
// Performs all the whole-graph control-arrow linking phases. The phase order matters: the
// no-input actors can only be computed after all kernel-level arrows exist, and the data
// prepare links depend on the no-input actors.
void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set,
                                            const GroupNameToCommuNodes &communication_node_groups,
                                            const std::vector<AbstractActor *> &auto_monad_actors,
                                            const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  for (const auto &communication_nodes : communication_node_groups) {
    // Link the control arrows by the communication nodes to ensure communication nodes running order.
    LinkControlArrowByCommunicationNode(communication_nodes.second, graph_compiler_info);
  }
  // Auto monad actor may modify the device tensor store.
  LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
  // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
  if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
    LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
                                        graph_compiler_info.control_node_parser_);
  }
  // Link control arrows for custom actor.
  LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
                                    graph_compiler_info.control_node_parser_);
  LinkControlArrowForOutputActor(actor_set->output_actor_.get(), actor_set);
}
// Links arrows for custom actors (e.g. dynamic-shape infer/update actors): Depend edges on
// custom-actor nodes become control arrows, dynamic-shape value dependencies become data
// arrows, and any custom actor with no incoming depend is triggered by the data prepare actor.
void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
                                                    const GraphCompilerInfo &graph_compiler_info) {
  constexpr size_t kDependFromIdx = 2;
  constexpr size_t kDependToIdx = 1;
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  // prepare for kernel => actor map
  HashMap<AnfNodePtr, AbstractActorPtr> kernel_to_actors = {};
  HashSet<CustomActorPtr> no_depend_custom_actors = {};
  for (const auto &actor : actor_set->custom_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
    // Assume no depend until one is found below.
    no_depend_custom_actors.insert(actor);
  }
  for (const auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
  }
  for (const auto &actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto device_data_source_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
    if (device_data_source_actor != nullptr) {
      auto kernel = device_data_source_actor->data_kernel();
      MS_EXCEPTION_IF_NULL(kernel);
      if (common::AnfAlgo::GetCNodeName(kernel) == kGetNextOpName) {
        kernel_to_actors.emplace(kernel, actor);
      }
    }
  }
  // find depend(custom, custom)
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // A sink graph runs as a whole, so per-node depend edges are irrelevant.
    if (graph->is_executing_sink()) {
      continue;
    }
    auto all_nodes = TopoSort(graph->get_return());
    for (const auto &node : all_nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
        continue;
      }
      auto depend_cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend_cnode);
      MS_EXCEPTION_IF_CHECK_FAIL(depend_cnode->size() > kDependFromIdx,
                                 "depend node " + depend_cnode->DebugString() + " input size " +
                                   std::to_string(depend_cnode->size()) + " is invalid.");
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependFromIdx));
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependToIdx));
      auto from_node = depend_cnode->input(kDependFromIdx);
      auto to_node = depend_cnode->input(kDependToIdx);
      // Only depend edges that touch a custom actor node are linked here.
      if (!AnfUtils::IsCustomActorNode(from_node) && !AnfUtils::IsCustomActorNode(to_node)) {
        continue;
      }
      auto from_iter = kernel_to_actors.find(from_node);
      if (from_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << from_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      auto to_iter = kernel_to_actors.find(to_node);
      if (to_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << to_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      AddControlArrow(from_iter->second.get(), to_iter->second.get());
      // This custom actor now has a trigger, so it no longer needs the data prepare arrow.
      no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
    }
  }
  // In control flow, no input actors should be linked to entrance actors.
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  if (parser->IsInited()) {
    return;
  }
  for (const auto &custom_actor : no_depend_custom_actors) {
    auto kernel = custom_actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
    MS_EXCEPTION_IF_NULL(base_node);
    // Inputs whose values are needed for dynamic-shape inference become data arrows.
    auto dynamic_shape_depends = abstract::GetDependsFormMap(base_node);
    for (auto iter = dynamic_shape_depends.begin(); iter != dynamic_shape_depends.end(); ++iter) {
      auto input_node = common::AnfAlgo::GetInputNode(base_node, *iter);
      KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
      auto graph = FetchKernelGraph(from_kernel_with_output_idx.first);
      auto kernel_type =
        FetchKernelTransformType(from_kernel_with_output_idx.first, graph, graph_compiler_info.origin_parameters_order_,
                                 graph_compiler_info.strategy_);
      auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, from_kernel_with_output_idx.first, graph);
      // The input_node maybe a data(Tensor) and the from_actor is nullptr
      if (from_actor != nullptr) {
        AddDataArrow(from_actor, custom_actor.get(), from_kernel_with_output_idx.first,
                     from_kernel_with_output_idx.second, *iter);
      }
    }
    AddControlArrow(actor_set->data_prepare_actor_.get(), custom_actor.get());
  }
}
  1461. void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  1462. const GraphCompilerInfo &graph_compiler_info) {
  1463. const size_t kCommunicationNodesMinNum = 2;
  1464. if (communication_nodes.size() < kCommunicationNodesMinNum) {
  1465. return;
  1466. }
  1467. // Ensure communication node to execute orderly.
  1468. for (size_t i = 1; i < communication_nodes.size(); ++i) {
  1469. auto from_actor = FetchActor(communication_nodes[i - 1]->fullname_with_scope());
  1470. auto to_actor = FetchActor(communication_nodes[i]->fullname_with_scope());
  1471. MS_EXCEPTION_IF_NULL(from_actor);
  1472. MS_EXCEPTION_IF_NULL(to_actor);
  1473. AddControlArrow(from_actor, to_actor);
  1474. }
  1475. // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
  1476. // Using the multi stream to optimize the performance in the future.
  1477. for (auto &graph : graph_compiler_info.graphs_) {
  1478. MS_EXCEPTION_IF_NULL(graph);
  1479. auto &execution_order = graph->execution_order();
  1480. for (size_t i = 1; i < execution_order.size(); ++i) {
  1481. auto from_actor = FetchActor(execution_order[i - 1]->fullname_with_scope());
  1482. auto to_actor = FetchActor(execution_order[i]->fullname_with_scope());
  1483. if ((from_actor != nullptr) && (to_actor != nullptr)) {
  1484. AddControlArrow(from_actor, to_actor);
  1485. }
  1486. }
  1487. }
  1488. }
  1489. void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
  1490. const ActorSet *actor_set,
  1491. const ControlNodeParserPtr &parser) {
  1492. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  1493. MS_EXCEPTION_IF_NULL(actor_set);
  1494. MS_EXCEPTION_IF_NULL(parser);
  1495. // Data prepare actor --> data source actor.
  1496. for (auto &data_source_actor : actor_set->data_source_actors_) {
  1497. MS_EXCEPTION_IF_NULL(data_source_actor);
  1498. AddControlArrow(data_prepare_actor, data_source_actor.get());
  1499. }
  1500. // In control flow, control arrow of no input kernel actor needs to be connected to the corresponding entrance actor.
  1501. if (!parser->IsInited()) {
  1502. // Data prepare actor --> no input kernel actor.
  1503. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  1504. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  1505. AddControlArrow(data_prepare_actor, no_input_kernel_actor.get());
  1506. }
  1507. }
  1508. // Data prepare actor --> loop count actor.
  1509. if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
  1510. (actor_set->loop_count_actor_ != nullptr)) {
  1511. AddControlArrow(data_prepare_actor, actor_set->loop_count_actor_.get());
  1512. }
  1513. }
  1514. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  1515. const ControlNodeParserPtr &parser) {
  1516. MS_EXCEPTION_IF_NULL(actor_set);
  1517. MS_EXCEPTION_IF_NULL(parser);
  1518. // There is no loop count actor in step mode.
  1519. if (loop_count_actor == nullptr) {
  1520. return;
  1521. }
  1522. // Collect the actors which have no output.
  1523. std::vector<AbstractActor *> no_output_actors;
  1524. for (auto &super_actor : actor_set->super_kernel_actors_) {
  1525. if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
  1526. (void)no_output_actors.emplace_back(super_actor.get());
  1527. }
  1528. }
  1529. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1530. // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
  1531. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
  1532. (void)no_output_actors.emplace_back(kernel_actor.get());
  1533. }
  1534. }
  1535. for (auto &data_actor : actor_set->data_source_actors_) {
  1536. if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
  1537. (void)no_output_actors.emplace_back(data_actor.get());
  1538. }
  1539. }
  1540. for (auto &copy_actor : copy_actors_) {
  1541. if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
  1542. (void)no_output_actors.emplace_back(copy_actor.get());
  1543. }
  1544. }
  1545. for (auto &custom_actor : actor_set->custom_actors_) {
  1546. if ((custom_actor->output_data_arrows_.size() == 0) && (custom_actor->output_control_arrows_.size() == 0)) {
  1547. (void)no_output_actors.emplace_back(custom_actor.get());
  1548. }
  1549. }
  1550. // No output actor --> loop count actor.
  1551. // In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
  1552. if (!parser->IsInited()) {
  1553. for (auto &no_output_actor : no_output_actors) {
  1554. AddControlArrow(no_output_actor, loop_count_actor);
  1555. }
  1556. }
  1557. // Loop count actor --> output actor.
  1558. AddControlArrow(loop_count_actor, actor_set->output_actor_.get());
  1559. // Loop count actor --> data prepare actor.
  1560. MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  1561. loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
  1562. actor_set->data_prepare_actor_->input_controls_num_++;
  1563. (void)actor_set->data_prepare_actor_->input_control_arrow_aids_.emplace_back(loop_count_actor->GetAID());
  1564. }
  1565. void GraphScheduler::LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set) {
  1566. MS_EXCEPTION_IF_NULL(actor_set);
  1567. // There is no output actor in step mode.
  1568. if (output_actor == nullptr) {
  1569. return;
  1570. }
  1571. // Output actor --> data prepare actor.
  1572. // The output actor needs to free the output memory in the running and needs this control arrow.
  1573. AddControlArrow(output_actor, actor_set->data_prepare_actor_.get());
  1574. }
// Links result arrows from each graph's output-producing actors to the output actor, mapping
// every backend output to its position(s) in the origin (front) output order. Device tensor
// store outputs are recorded as store keys instead of arrows.
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
                                                         const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
      (graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
    // In control flow, the exit actor of the root graph sends output data to the output actor.
    return;
  }
  MS_EXCEPTION_IF_NULL(to_actor);
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
    std::set<std::vector<size_t>> unique_output_positions;
    std::set<KernelWithIndex> unique_outputs;
    for (const auto &output : outputs) {
      // An internal parameter output belongs to another graph's producer, linked elsewhere.
      if (IsInternalParameter(output.first, graph)) {
        MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
        continue;
      }
      (void)unique_outputs.insert(output);
    }
    for (const auto &output_with_index : unique_outputs) {
      MS_EXCEPTION_IF_NULL(output_with_index.first);
      auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
      const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
      if (iter == graph_compiler_info.origin_outputs_order_.end()) {
        continue;
      }
      // Skip duplicate position.
      if (unique_output_positions.count(iter->second) > 0) {
        continue;
      }
      (void)unique_output_positions.insert(iter->second);
      // One backend output may map to several positions in the front output tuple.
      for (auto &output_position : iter->second) {
        if (output_position >= to_actor->device_contexts_.size()) {
          MS_LOG(EXCEPTION) << "The output position is out of range.";
        }
        to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
        // The graph output is from device tensor store.
        if (IsPersistentDeviceTensor(output_with_index.first)) {
          (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
          if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
            MS_EXCEPTION_IF_NULL(output_with_index.first);
            MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address not exit";
            continue;
          }
          // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
          // as output_idx directly.
          auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
          MS_EXCEPTION_IF_NULL(device_tensor);
          // The output actor need use the relevant information of node to create output tensor.
          device_tensor->SetNodeIndex(output_with_index.first, 0);
          continue;
        }
        // The graph output is from kernel actor or data source actor.
        auto kernel_type = FetchKernelTransformType(
          output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_);
        auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph);
        if (from_actor == nullptr) {
          continue;
        }
        auto real_from_kernel = output_with_index.first;
        // Update the real node in the host data source actor.
        if (kernel_type == KernelTransformType::kHostDataSourceActor) {
          auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
          MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
          auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first);
          real_from_kernel = host_queue_ds_actor->FetchNode(position);
          UpdateRefCount(output_with_index.first, output_with_index.second, true);
        }
        AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second, output_position);
      }
    }
  }
}
  1650. void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors) {
  1651. const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  1652. for (auto &auto_monad_actor : auto_monad_actors) {
  1653. MS_EXCEPTION_IF_NULL(auto_monad_actor);
  1654. for (auto &device_tensor_store_key : auto_monad_actor->device_tensor_store_keys_) {
  1655. auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
  1656. if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
  1657. continue;
  1658. }
  1659. // Find the device tensor store that needs to be processed accurately.
  1660. if ((auto_monad_actor->type_ == KernelTransformType::kSuperKernelActor) &&
  1661. (auto_monad_actor->auto_monad_device_tensor_stores_.find(device_tensor_store_key.second) ==
  1662. auto_monad_actor->auto_monad_device_tensor_stores_.end())) {
  1663. continue;
  1664. }
  1665. // Create the copy actor.
  1666. std::string name = "copy_from:" + auto_monad_actor->GetAID().Name() +
  1667. "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
  1668. if (FetchActor(name) != nullptr) {
  1669. continue;
  1670. }
  1671. auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
  1672. MS_EXCEPTION_IF_NULL(copy_actor);
  1673. (void)copy_actors_.emplace_back(copy_actor);
  1674. InsertActor(copy_actor.get());
  1675. // Set the member of the copy actor.
  1676. (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
  1677. auto input_device_context = auto_monad_actor->device_contexts_[0];
  1678. (void)copy_actor->device_contexts_.emplace_back(input_device_context);
  1679. auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
  1680. ? device_tensors[1]
  1681. : device_tensors[0];
  1682. MS_EXCEPTION_IF_NULL(another_device_tensor);
  1683. auto another_device_type = another_device_tensor->DeviceType();
  1684. const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
  1685. {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
  1686. MS_EXCEPTION_IF_NULL(another_device_context);
  1687. (void)copy_actor->device_contexts_.emplace_back(another_device_context);
  1688. MS_LOG(INFO) << "The auto monad actor: " << auto_monad_actor->GetAID().Name()
  1689. << "has control arrows number:" << auto_monad_actor->output_control_arrows_.size();
  1690. // Link from copy actor to auto monad actor users.
  1691. for (auto &output_contorl : auto_monad_actor->output_control_arrows_) {
  1692. (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
  1693. }
  1694. // Move the control arrows from auto monad actor to auto monad actor users.
  1695. auto_monad_actor->output_control_arrows_.clear();
  1696. // Link from auto monad actor to copy actor.
  1697. AddControlArrow(auto_monad_actor, copy_actor.get());
  1698. }
  1699. }
  1700. }
  1701. void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
  1702. MS_EXCEPTION_IF_NULL(device_tensor);
  1703. DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
  1704. UpdateRefCount(device_tensor.get(), true);
  1705. }
  1706. void GraphScheduler::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1707. const AnfNodePtr &from_kernel, size_t from_output_index, size_t to_input_index) {
  1708. MS_EXCEPTION_IF_NULL(from_actor);
  1709. MS_EXCEPTION_IF_NULL(to_actor);
  1710. auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
  1711. (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
  1712. (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
  1713. to_actor->input_datas_num_++;
  1714. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1715. if (from_kernel == nullptr) {
  1716. return;
  1717. }
  1718. // Update the reference count of from_kernel.
  1719. // The device address of super kernel actor can't be changed, so set the max reference count.
  1720. if ((from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
  1721. (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
  1722. UpdateRefCount(from_kernel, from_output_index, true);
  1723. } else {
  1724. UpdateRefCount(from_kernel, from_output_index, false);
  1725. }
  1726. }
  1727. void GraphScheduler::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
  1728. const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
  1729. MS_EXCEPTION_IF_NULL(from_actor);
  1730. MS_EXCEPTION_IF_NULL(to_actor);
  1731. MS_EXCEPTION_IF_NULL(from_kernel);
  1732. auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
  1733. (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
  1734. (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
  1735. to_actor->input_datas_num_++;
  1736. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1737. auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
  1738. MS_EXCEPTION_IF_NULL(device_tensor);
  1739. // The output actor need use the relevant information of node to create output tensor.
  1740. device_tensor->SetNodeIndex(from_kernel, from_output_index);
  1741. // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
  1742. UpdateRefCount(device_tensor.get(), true);
  1743. }
  1744. void GraphScheduler::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
  1745. MS_EXCEPTION_IF_NULL(from_actor);
  1746. MS_EXCEPTION_IF_NULL(to_actor);
  1747. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1748. to_actor->input_controls_num_++;
  1749. (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
  1750. }
  1751. void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
  1752. MS_EXCEPTION_IF_NULL(actor_set);
  1753. auto actors = CollectActors(actor_set);
  1754. for (auto &actor : actors) {
  1755. MS_EXCEPTION_IF_NULL(actor);
  1756. if (actor->type_ >= KernelTransformType::kSwitchActor) {
  1757. continue;
  1758. }
  1759. if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
  1760. (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
  1761. MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name()
  1762. << " is wrong, expect data num: " << actor->input_datas_num_
  1763. << ", actual data num: " << actor->input_data_arrow_aids_.size()
  1764. << ", expect control num: " << actor->input_controls_num_
  1765. << ", actual control num: " << actor->input_control_arrow_aids_.size();
  1766. }
  1767. if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
  1768. (actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
  1769. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
  1770. }
  1771. if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
  1772. (actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
  1773. (actor->input_controls_num_ == 0)) {
  1774. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
  1775. }
  1776. // Check the input of kernel actors and copy actors.
  1777. if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
  1778. size_t expect_toal_input_num = 1;
  1779. if (actor->type_ == KernelTransformType::kKernelActor) {
  1780. auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
  1781. MS_EXCEPTION_IF_NULL(kernel_actor);
  1782. expect_toal_input_num = common::AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
  1783. }
  1784. auto input_data_num = actor->input_datas_num_;
  1785. auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
  1786. if (input_data_num + device_tensor_store_num != expect_toal_input_num) {
  1787. MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name()
  1788. << " is wrong, input data num: " << input_data_num
  1789. << ", device tensor store num: " << device_tensor_store_num
  1790. << ", total input num: " << expect_toal_input_num;
  1791. }
  1792. }
  1793. }
  1794. // Check the output actor.
  1795. auto output_actor = actor_set->output_actor_;
  1796. MS_EXCEPTION_IF_NULL(output_actor);
  1797. if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num_) {
  1798. MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
  1799. << output_actor->outputs_num_ << ", the input data arrows num: " << output_actor->input_datas_num_
  1800. << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
  1801. }
  1802. control_node_scheduler_.CheckActorValid(actor_set);
  1803. }
  1804. void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
  1805. const auto &parser = graph_compiler_info.control_node_parser_;
  1806. MS_EXCEPTION_IF_NULL(parser);
  1807. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1808. const auto &graph = graph_compiler_info.graphs_[i];
  1809. const auto &device_context = graph_compiler_info.device_contexts_[i];
  1810. MS_EXCEPTION_IF_NULL(graph);
  1811. MS_EXCEPTION_IF_NULL(device_context);
  1812. for (auto &value_node : graph->graph_value_nodes()) {
  1813. MS_EXCEPTION_IF_NULL(value_node);
  1814. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1815. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  1816. continue;
  1817. }
  1818. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
  1819. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1820. device_tensor->SetNodeIndex(value_node, 0);
  1821. AddDeviceTensorStore(front_node.get(), device_tensor);
  1822. }
  1823. for (auto &input_node : graph->input_nodes()) {
  1824. MS_EXCEPTION_IF_NULL(input_node);
  1825. AnfNodePtr front_node = nullptr;
  1826. if (IsInternalParameter(input_node, graph)) {
  1827. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
  1828. front_node = front_output_with_index.first;
  1829. } else if (IsPersistentDeviceTensor(input_node)) {
  1830. front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1831. }
  1832. // The front node may be value node in the heterogeneous scene, needs to handle.
  1833. if ((front_node == nullptr) ||
  1834. (front_node->isa<Parameter>() && !parser->IsRootGraphPersistentDeviceTensor(front_node))) {
  1835. continue;
  1836. }
  1837. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  1838. MS_EXCEPTION_IF_NULL(device_tensor);
  1839. if (IsPersistentDeviceTensor(input_node) || device_tensor->is_ptr_persisted()) {
  1840. device_tensor->SetNodeIndex(input_node, 0);
  1841. AddDeviceTensorStore(front_node.get(), device_tensor);
  1842. }
  1843. // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
  1844. if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
  1845. MS_LOG(WARNING) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
  1846. << ", type:" << device_context->GetDeviceAddressType();
  1847. auto other_type_device_tensor =
  1848. device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
  1849. device_tensor->type_id(), device_tensor->host_shape());
  1850. other_type_device_tensor->SetNodeIndex(input_node, 0);
  1851. other_type_device_tensor->set_from_persistent_mem(input_node->isa<Parameter>());
  1852. AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
  1853. }
  1854. }
  1855. }
  1856. PersistDeviceTensorForRootGraphControlNode(graph_compiler_info);
  1857. }
// For each persistent root graph parameter (weight) that has no device tensor store entry yet, create
// a fresh device tensor modeled on its first backend parameter and register it under the root graph
// parameter key. Only runs in the control flow scene (parser initialized).
void GraphScheduler::PersistDeviceTensorForRootGraphControlNode(const GraphCompilerInfo &graph_compiler_info) {
  const auto &parser = graph_compiler_info.control_node_parser_;
  // No control nodes parsed: nothing to persist here.
  if (parser == nullptr || (!parser->IsInited())) {
    return;
  }
  for (auto &root_graph_parameter : graph_compiler_info.origin_parameters_order_) {
    MS_EXCEPTION_IF_NULL(root_graph_parameter);
    // Only persistent device tensors (weight parameters, per the exception text below) are handled.
    if (!IsPersistentDeviceTensor(root_graph_parameter)) {
      continue;
    }
    // The device tensor store has been done in the backend kernel graph corresponding to the root graph.
    if (!DeviceTensorStore::GetInstance().Fetch(root_graph_parameter.get()).empty()) {
      continue;
    }
    // The different root graph parameters may correspond to parameter of same sub kernel graph when call the same sub
    // graph using the different root graph parameters. So can not use the device tensor of sub kernel graph parameter
    // directly and choose the first backend parameter in sub kernel graphs to create new device tensor to make sure
    // that the device tensor of root graph parameters are different.
    const auto &backend_parameter_with_context =
      parser->FetchBackendParameterWithContextByFrontParameter({root_graph_parameter, 0});
    if (backend_parameter_with_context.first == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << root_graph_parameter->DebugString();
    }
    const auto &backend_node = backend_parameter_with_context.first;
    const auto &device_context = backend_parameter_with_context.second;
    MS_EXCEPTION_IF_NULL(backend_node);
    MS_EXCEPTION_IF_NULL(device_context);
    auto sub_device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
    MS_EXCEPTION_IF_NULL(sub_device_tensor);
    // New address with no data pointer yet, mirroring the backend parameter's size/format/type/shape.
    auto new_device_tensor =
      device_context->CreateDeviceAddress(nullptr, sub_device_tensor->GetSize(), sub_device_tensor->format(),
                                          sub_device_tensor->type_id(), sub_device_tensor->host_shape());
    MS_EXCEPTION_IF_NULL(new_device_tensor);
    new_device_tensor->SetNodeIndex(backend_node, 0);
    new_device_tensor->set_is_ptr_persisted(sub_device_tensor->is_ptr_persisted());
    // NOTE(review): presumably marks the address for persistent (weight) memory allocation — confirm.
    new_device_tensor->set_from_persistent_mem(true);
    AddDeviceTensorStore(root_graph_parameter.get(), new_device_tensor);
    MS_LOG(INFO) << "Add device tensor store by root graph parameter:" << root_graph_parameter->fullname_with_scope()
                 << ", backend node:" << backend_node->DebugString()
                 << ", type:" << device_context->GetDeviceAddressType();
  }
}
  1900. void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
  1901. MS_EXCEPTION_IF_NULL(actor_set);
  1902. const auto &context_ptr = MsContext::GetInstance();
  1903. MS_EXCEPTION_IF_NULL(context_ptr);
  1904. auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1905. if (!save_graphs) {
  1906. return;
  1907. }
  1908. // Get the saved actor set name.
  1909. auto &kernel_graphs = graph_compiler_info.graphs_;
  1910. MS_EXCEPTION_IF_NULL(kernel_graphs.front());
  1911. auto first_graph_id = kernel_graphs.front()->graph_id();
  1912. MS_EXCEPTION_IF_NULL(kernel_graphs.back());
  1913. auto last_graph_id = kernel_graphs.back()->graph_id();
  1914. std::string strategy = (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) ? "pipeline" : "step";
  1915. std::string save_name = "actor_set_" + strategy + "_kernel_graph_" + std::to_string(first_graph_id);
  1916. if (last_graph_id != first_graph_id) {
  1917. save_name = save_name + "-" + std::to_string(last_graph_id);
  1918. }
  1919. std::string filename = GetSaveGraphsPathName(save_name + ".ir");
  1920. std::ofstream ofs(filename);
  1921. if (!ofs.is_open()) {
  1922. MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
  1923. return;
  1924. }
  1925. DumpDeviceTensorStore(graph_compiler_info, ofs);
  1926. DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
  1927. DumpDSActors(actor_set->data_source_actors_, ofs);
  1928. DumpKernelActors(actor_set->kernel_actors_, ofs);
  1929. DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
  1930. // The on input kernel actors are taken over by control actor in the control flow scene.
  1931. if ((graph_compiler_info.control_node_parser_ == nullptr) ||
  1932. (!graph_compiler_info.control_node_parser_->IsInited())) {
  1933. DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
  1934. }
  1935. DumpCopyActors(actor_set->copy_actors_, ofs);
  1936. DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
  1937. DumpOutputActor(actor_set->output_actor_, ofs);
  1938. DumpControlActors(actor_set->control_actors_, ofs);
  1939. DumpCustomActors(actor_set->custom_actors_, ofs);
  1940. }
  1941. void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  1942. ofs << "[Device tensor stores]\n";
  1943. for (const auto &graph : graph_compiler_info.graphs_) {
  1944. MS_EXCEPTION_IF_NULL(graph);
  1945. ofs << "\tgraph_id:" << graph->graph_id() << "\tis_executing_sink:" << graph->is_executing_sink()
  1946. << "\tis_loop_count_sink:" << graph->is_loop_count_sink()
  1947. << "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n";
  1948. for (auto &value_node : graph->graph_value_nodes()) {
  1949. MS_EXCEPTION_IF_NULL(value_node);
  1950. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1951. continue;
  1952. }
  1953. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1954. MS_EXCEPTION_IF_NULL(front_node);
  1955. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1956. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1957. << "\n";
  1958. for (const auto &device_tensor : device_tensors) {
  1959. MS_EXCEPTION_IF_NULL(device_tensor);
  1960. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1961. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1962. << "\tdynamic_ref_count:" << device_tensor->dynamic_ref_count()
  1963. << "\tdevice_type:" << device_tensor->DeviceType()
  1964. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1965. }
  1966. }
  1967. for (auto &input_node : graph->input_nodes()) {
  1968. MS_EXCEPTION_IF_NULL(input_node);
  1969. if (!IsPersistentDeviceTensor(input_node)) {
  1970. continue;
  1971. }
  1972. const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1973. const auto &root_parameters = graph_compiler_info.origin_parameters_order_;
  1974. if (front_node == nullptr ||
  1975. find(root_parameters.begin(), root_parameters.end(), front_node) == root_parameters.end()) {
  1976. continue;
  1977. }
  1978. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1979. MS_EXCEPTION_IF_NULL(front_node);
  1980. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1981. << "\n";
  1982. for (const auto &device_tensor : device_tensors) {
  1983. MS_EXCEPTION_IF_NULL(device_tensor);
  1984. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1985. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1986. << "\tdynamic_ref_count:" << device_tensor->dynamic_ref_count()
  1987. << "\tdevice_type:" << device_tensor->DeviceType()
  1988. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1989. }
  1990. }
  1991. ofs << "\n";
  1992. for (auto &backend_front_map : graph->backend_front_anf_map()) {
  1993. MS_EXCEPTION_IF_NULL(backend_front_map.first);
  1994. MS_EXCEPTION_IF_NULL(backend_front_map.second);
  1995. MS_LOG(DEBUG) << "Graph: " << graph->graph_id()
  1996. << ", backend node: " << backend_front_map.first->fullname_with_scope()
  1997. << ", front node: " << backend_front_map.second->DebugString();
  1998. }
  1999. }
  2000. }
  2001. } // namespace runtime
  2002. } // namespace mindspore